In [1]:
import sys, os
import argparse

# Prepend ancestor directories (four and three levels above the current working
# directory) to the import search path.
# NOTE(review): presumably this is what makes the `fedml_api` package imported
# in the next cell resolvable — confirm against the repository layout.
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "./../../../../")))
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "./../../../")))

# Prepend the current working directory itself; inserted last, so it ends up
# first on sys.path. Presumably needed for the local `CovaMNet` module — verify.
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "")))

In [2]:
from fedml_api.data_preprocessing.cifar10.data_loader import load_partition_data_cifar10
from fedml_api.standalone.fedavg.my_model_trainer_classification import MyModelTrainer as MyModelTrainerCLS
from fedml_api.model.contrastive_cv.resnet_with_embedding import Resnet56
from CovaMNet import CovaMResnet56

import torch
from torch import nn
from collections import OrderedDict
import torch.nn.functional as F

import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import numpy as np
import random
import pickle

In [3]:
# Experiment configuration.
# NOTE(review): `dataset` starts out as a name string here but is rebound to the
# unpickled object at the end of this cell; later cells index into that object.
dataset = 'cifar10'
data_dir = "./../../../data/cifar10"
# partition_method = 'hetero'
partition_method = 'homo'
partition_alpha = 0.5
client_num_in_total = 3
batch_size = 100
# NOTE(review): not referenced by the cells visible here (training is invoked
# with an explicit epoch count).
total_epochs = 500

# Checkpoint path template, filled by train_model via str.format with
# (client_num_in_total, partition_method, client id, epoch).
save_model_path = 'model/cs_{0}_{1}_client_{2}_better_orcal_no_bnneck_resnet_{3}.pt'

# Device string used for both training and evaluation.
device = 'cuda:3'

# Load the pre-partitioned dataset whose file name is derived from the
# partition method and client count above. This rebinds `dataset`.
# SECURITY: pickle.load executes arbitrary code from the file — only load
# pickles produced by a trusted source.
with open(f'dataset_{partition_method}_{client_num_in_total}.pickle', 'rb') as f:
    dataset = pickle.load(f)

In [4]:
def train_model(client, epochs):
    """Train ``client.model`` on ``client.train_data`` with SGD and checkpoint it.

    The model is expected to return a ``(logits, features)`` tuple; only the
    logits are used for the cross-entropy loss. Checkpoints are written every
    50 epochs (including epoch 0) and once more after the last epoch, using the
    module-level ``save_model_path`` template formatted with
    ``client_num_in_total``, ``partition_method``, ``client.id`` and the epoch.

    Args:
        client: object exposing ``id``, ``device``, ``model`` and ``train_data``
            (an iterable of ``(inputs, labels)`` batches).
        epochs: number of passes over ``client.train_data``.

    Returns:
        None. Progress is printed once per epoch.
    """
    lr = 0.01
    weight_decay = 5e-4
    # Use the client's own device everywhere. The model is moved to
    # client.device, so batches and the criterion must live there too rather
    # than on whatever the notebook-level `device` happens to be.
    dev = client.device

    def _save_checkpoint(tag):
        # Create the checkpoint directory on demand so a fresh checkout does
        # not fail on the very first save.
        path = str.format(save_model_path, client_num_in_total, partition_method, client.id, tag)
        ckpt_dir = os.path.dirname(path)
        if ckpt_dir:
            os.makedirs(ckpt_dir, exist_ok=True)
        torch.save(client.model.state_dict(), path)

    client.model.to(dev)
    client.model.train()

    criterion = nn.CrossEntropyLoss().to(dev)
    optimizer = torch.optim.SGD(client.model.parameters(), lr=lr,
                                momentum=0.9, weight_decay=weight_decay)
    # NOTE(review): T_max is fixed at 200 regardless of `epochs`; the cosine
    # schedule only completes exactly one half-period when epochs == 200.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

    epoch_loss = []
    for epoch in range(epochs):
        batch_loss = []
        for x, labels in client.train_data:
            x, labels = x.to(dev), labels.to(dev)
            optimizer.zero_grad()
            log_probs, _ = client.model(x)
            loss = criterion(log_probs, labels)
            loss.backward()

            # Clip the gradient norm to avoid NaN losses.
            torch.nn.utils.clip_grad_norm_(client.model.parameters(), 1.0)

            optimizer.step()
            batch_loss.append(loss.item())

        scheduler.step()
        if epoch % 50 == 0:
            _save_checkpoint(epoch)
        epoch_loss.append(sum(batch_loss) / len(batch_loss))
        # The reported value is the mean of the per-epoch mean losses seen so
        # far (a running average), not the loss of this epoch alone.
        print('Client Index = {}\tEpoch: {}\tLoss: {:.6f}'.format(
            client.id, epoch, sum(epoch_loss) / len(epoch_loss)))

    _save_checkpoint(epochs)

class Client(object):
    """One federated client: its local data, sample count, device and model.

    The three ``*_dict`` arguments are mappings keyed by client index; only the
    entry for ``client_index`` is kept on the instance.
    """

    def __init__(self, client_index, train_data_local_dict, train_data_local_num_dict, test_data_local_dict, device, model):
        self.id = client_index
        key = self.id

        # Select this client's entries from the per-client mappings.
        self.train_data = train_data_local_dict[key]
        self.local_sample_number = train_data_local_num_dict[key]
        self.test_local = test_data_local_dict[key]

        # Where to run, and what to run.
        self.device, self.model = device, model
    

# clients = []
# for i in range(3):
#     client = Client(i, train_data_local_dict, train_data_local_num_dict, test_data_local_dict, device, resnet56(class_num=class_num))

#     train_model(client, epochs)

In [5]:
# Build the backbone; the class count is taken from the last entry of the
# loaded dataset tuple.
model = Resnet56(class_num=dataset[-1], neck='no')
# Layout of `dataset` according to the original author's note (TODO confirm
# against the code that produced the pickle):
#   [train_data_num, test_data_num, train_data_global, test_data_global,
#    train_data_local_num_dict, train_data_local_dict, test_data_local_dict,
#    class_num]
# Client 0 therefore receives indices 5 (train loaders), 4 (sample counts)
# and 6 (test loaders).
client = Client(0, dataset[5], dataset[4], dataset[6], device, model)
# model.load_state_dict(torch.load('model/cs_3_homo_client_0_oral_epochs_200.pt'))

In [6]:
# Train client 0 for 200 epochs; this matches the T_max=200 hard-coded in
# train_model's cosine schedule. Long-running cell.
train_model(client, 200)

Client Index = 0	Epoch: 0	Loss: 1.987793
Client Index = 0	Epoch: 1	Loss: 1.852640
Client Index = 0	Epoch: 2	Loss: 1.764082
Client Index = 0	Epoch: 3	Loss: 1.697759
Client Index = 0	Epoch: 4	Loss: 1.641455
Client Index = 0	Epoch: 5	Loss: 1.592085
Client Index = 0	Epoch: 6	Loss: 1.547518
Client Index = 0	Epoch: 7	Loss: 1.508914
Client Index = 0	Epoch: 8	Loss: 1.472901
Client Index = 0	Epoch: 9	Loss: 1.441112
Client Index = 0	Epoch: 10	Loss: 1.411756
Client Index = 0	Epoch: 11	Loss: 1.383939
Client Index = 0	Epoch: 12	Loss: 1.358197
Client Index = 0	Epoch: 13	Loss: 1.334200
Client Index = 0	Epoch: 14	Loss: 1.311707
Client Index = 0	Epoch: 15	Loss: 1.290816
Client Index = 0	Epoch: 16	Loss: 1.271332
Client Index = 0	Epoch: 17	Loss: 1.252163
Client Index = 0	Epoch: 18	Loss: 1.234288
Client Index = 0	Epoch: 19	Loss: 1.216818
Client Index = 0	Epoch: 20	Loss: 1.200671
Client Index = 0	Epoch: 21	Loss: 1.185388
Client Index = 0	Epoch: 22	Loss: 1.170292
Client Index = 0	Epoch: 23	Loss: 1.155842
Cl

In [24]:
def test(model, test_data, device):
    model.to(device)
    model.eval()

    metrics = {
        'test_correct': 0,
        'test_loss': 0,
        'test_total': 0
    }

    criterion = nn.CrossEntropyLoss().to(device)

    with torch.no_grad():
        for batch_idx, (x, target) in enumerate(test_data):
            x = x.to(device)
            target = target.to(device)
            pred, feat = model(x)
            print(feat.sum())
            loss = criterion(pred, target)

            _, predicted = torch.max(pred, -1)
            correct = predicted.eq(target).sum()

            metrics['test_correct'] += correct.item()
            metrics['test_loss'] += loss.item() * target.size(0)
            metrics['test_total'] += target.size(0)
            
    return metrics

# Evaluate the trained client model on dataset[3] (the global test loader
# according to the tuple layout noted earlier — TODO confirm).
metrics = test(client.model, dataset[3], device)
# Convert the raw counters into accuracy and mean loss.
# NOTE: raises ZeroDivisionError if the loader yielded no samples.
test_correct = metrics['test_correct']/metrics['test_total']
test_loss = metrics['test_loss']/metrics['test_total']
print(metrics['test_total'])
print(f'test_correct: {test_correct}; test_loss: {test_loss}')

tensor(1384288., device='cuda:3')
tensor(1395984.3750, device='cuda:3')
tensor(1387900.3750, device='cuda:3')
tensor(1389824., device='cuda:3')
tensor(1385503., device='cuda:3')
tensor(1432972.2500, device='cuda:3')
tensor(1424047.1250, device='cuda:3')
tensor(1391063.2500, device='cuda:3')
tensor(1404191.7500, device='cuda:3')
tensor(1403830.3750, device='cuda:3')
tensor(1406911., device='cuda:3')
tensor(1406319.3750, device='cuda:3')
tensor(1424981.6250, device='cuda:3')
tensor(1393080.2500, device='cuda:3')
tensor(1426188.2500, device='cuda:3')
tensor(1411552.3750, device='cuda:3')
tensor(1383061.7500, device='cuda:3')
tensor(1391260.7500, device='cuda:3')
tensor(1433140.3750, device='cuda:3')
tensor(1406451., device='cuda:3')
tensor(1402039.5000, device='cuda:3')
tensor(1421068.1250, device='cuda:3')
tensor(1427921.5000, device='cuda:3')
tensor(1400979.6250, device='cuda:3')
tensor(1404369.8750, device='cuda:3')
tensor(1413765.5000, device='cuda:3')
tensor(1390108.7500, device='cud

In [31]:
def cal_covariance(input):
    """Compute a channel covariance matrix and channel mean for every sample.

    Each element of ``input`` must be a 3-D tensor ``(C, h, w)``. Its spatial
    positions are flattened into ``h * w`` observations per channel, the
    per-channel mean is subtracted, and the unbiased ``C x C`` covariance
    (divisor ``h * w - 1``) is computed.

    Args:
        input: a batch tensor ``(N, C, h, w)`` or any sequence of ``(C, h, w)``
            tensors.

    Returns:
        ``(covariances, means)``: two lists with one entry per sample, holding
        the ``(C, C)`` covariance matrix and the ``(C, 1)`` channel mean.
    """
    covariances = []
    means = []

    for sample in input:
        channels, height, width = sample.size()

        # Rows are channels, columns are spatial positions.
        flat = sample.contiguous().view(channels, -1)

        channel_mean = flat.mean(dim=1, keepdim=True)
        means.append(channel_mean)

        centered = flat - channel_mean
        # Unbiased estimate: divide by (number of observations - 1).
        covariances.append((centered @ centered.t()) / (height * width - 1))

    return covariances, means

# Module-level accumulators mutated in place by extract_features.
# NOTE(review): re-running the extraction cell without re-running this one
# keeps accumulating into the same state (non-idempotent).

# One running covariance matrix (256 x 256) per class index 0..9.
cl = [torch.zeros((256,256)) for i in range(10)]
# One running flattened feature map (256 x 64) per class index 0..9. Despite
# the name, extract_features stores the averaged raw features here, not the
# channel means returned by cal_covariance.
ml = [torch.zeros((256,64)) for i in range(10)]
# Exponential-moving-average decay: new = lbd * old + (1 - lbd) * sample.
lbd = 0.999
# Every label seen by extract_features is appended here (as a CPU tensor).
labels = []
def extract_features(model, data_loader, device):
    """Update the per-class running statistics from every sample in ``data_loader``.

    For each sample the model's feature output is turned into a covariance
    matrix by ``cal_covariance``. The module-level lists ``cl`` (covariances)
    and ``ml`` (flattened features) are then updated in place for the sample's
    class with an exponential moving average using decay ``lbd``. Every label
    is also appended to the module-level ``labels`` list, including labels
    that fall outside ``range(len(cl))`` (those update no statistics).

    Args:
        model: module returning a ``(scores, features)`` tuple; it is moved to
            ``device`` and put in eval mode.
        data_loader: iterable of ``(inputs, labels)`` batches.
        device: device on which to run the forward pass.

    Returns:
        None. Results are accumulated in ``cl``, ``ml`` and ``labels``.
    """
    model.to(device)
    model.eval()

    num_classes = len(cl)

    with torch.no_grad():
        for x, l in data_loader:
            x = x.to(device)
            # Labels are only used for bookkeeping, so bring them to the CPU
            # once per batch instead of synchronising with the device for
            # every sample and every candidate class.
            batch_labels = l.cpu()

            _, feats = model(x)
            covaM_list, _ = cal_covariance(feats)

            for covaM, f, label in zip(covaM_list, feats, batch_labels):
                labels.append(label)

                class_idx = label.item()
                if class_idx not in range(num_classes):
                    continue
                class_idx = int(class_idx)

                cl[class_idx] = lbd * cl[class_idx] + (1 - lbd) * covaM.cpu()
                # Flatten (C, h, w) features to (C, h*w) before averaging.
                flat = f.reshape(f.size(0), -1)
                ml[class_idx] = lbd * ml[class_idx] + (1 - lbd) * flat.cpu()
                        
# Accumulate per-class statistics over dataset[2] (the global training loader
# according to the tuple layout noted earlier — TODO confirm). This mutates the
# module-level `cl`, `ml` and `labels` in place.
extract_features(client.model, dataset[2], device)
# Bundle the two accumulator lists; both entries alias the live `cl` / `ml`.
covaMs_means = [cl, ml]
# print(cl[0].shape)
# with open(f'better_orca_no_bnneck_l_covaMs_means.pickle', 'wb') as f:
#     pickle.dump(covaMs_means, f)

8
8
6
7
9
9
6
8
6
5
7
8
7
9
5
5
5
8
7
3
9
6
7
9
3
7
5
0
6
7
2
5
1
9
4
3
9
8
2
9
8
2
1
3
3
2
8
9
6
4
6
9
8
5
4
3
0
6
2
4
9
7
7
9
2
5
2
1
8
6
9
2
5
7
6
2
5
7
4
3
9
5
5
7
9
1
0
7
0
4
5
3
1
4
5
2
2
3
6
6
5
3
9
1
4
2
4
5
4
6
9
9
1
2
8
9
9
9
1
4
7
1
8
7
9
3
6
4
5
7
7
1
1
6
9
0
2
8
9
2
5
4
9
9
6
4
6
6
0
8
4
9
9
6
7
4
8
4
6
9
2
3
8
5
8
8
0
6
4
5
7
7
3
9
1
7
2
7
7
7
4
7
9
8
2
0
8
7
2
4
5
3
7
6
8
8
7
0
3
0
9
1
0
2
2
3
5
9
9
2
7
4
7
7
9
9
8
8
4
2
5
1
7
1
3
3
5
7
0
2
0
8
5
8
0
4
5
9
8
4
4
0
4
4
0
0
4
8
3
4
2
0
8
8
7
8
8
5
0
7
3
0
3
8
1
8
7
6
7
1
2
5
2
7
3
4
4
0
8
6
6
3
5
3
6
0
8
2
1
3
8
3
0
7
0
8
3
8
7
8
9
1
5
7
1
8
9
3
8
4
5
2
6
1
8
4
6
2
1
4
2
1
2
2
4
2
9
2
4
9
5
9
0
5
0
7
1
5
2
4
7
0
6
3
7
5
7
8
6
7
9
0
1
8
8
9
7
4
2
6
3
9
6
2
4
8
8
0
9
2
9
1
0
3
9
3
0
8
7
6
5
3
3
8
2
4
8
9
3
5
1
3
2
2
3
5
2
9
3
3
9
0
9
6
0
2
6
7
1
9
0
6
6
0
9
1
5
5
2
9
2
3
9
3
6
0
6
6
8
4
4
1
3
8
2
3
2
5
6
8
2
3
6
4
2
0
3
2
2
8
1
3
5
9
8
3
9
7
6
1
8
4
0
0
2
4
2
6
3
8
8
4
3
8
7
2
7
0
0
6
2
6
3
5
1
4
8
8
8
7
2
9
0
8
7
9
8
5
4
2


9
3
2
7
2
1
7
5
2
1
5
6
3
2
4
7
5
9
6
4
9
9
6
1
5
1
2
8
9
6
0
2
2
9
2
3
8
1
9
8
4
1
3
6
9
5
4
7
3
2
7
2
4
5
3
5
9
4
1
6
5
2
6
9
3
3
9
0
0
8
2
1
5
9
5
8
2
4
1
3
5
9
4
5
8
1
4
0
9
8
2
2
1
9
4
8
2
4
9
4
6
7
1
2
3
0
4
8
5
7
7
3
4
2
9
5
6
6
5
0
4
0
4
0
8
0
4
6
9
4
5
8
2
7
2
4
7
1
8
4
6
8
5
2
9
0
0
9
8
0
8
9
6
8
3
5
8
4
8
3
2
0
0
8
1
2
3
9
2
5
2
8
7
4
9
4
9
9
9
5
4
0
3
1
2
0
1
7
8
0
3
0
3
4
1
5
8
9
0
3
3
0
0
6
5
2
1
0
5
0
3
8
0
5
7
8
3
2
6
5
0
1
1
9
7
4
2
2
1
1
3
4
4
9
7
0
8
8
4
3
4
0
7
8
6
5
3
1
2
3
2
3
2
2
9
5
0
3
4
1
3
4
6
4
7
4
9
2
3
5
7
5
1
0
5
5
6
6
0
0
4
7
9
6
8
3
7
6
1
9
1
6
7
0
9
1
0
1
0
1
0
6
6
4
1
4
5
2
1
9
4
1
5
4
3
4
5
9
2
7
5
1
4
9
6
9
1
3
1
4
7
4
9
9
2
2
6
2
0
4
9
5
2
5
2
8
1
1
8
6
5
5
3
1
6
7
1
0
1
0
9
0
8
2
2
4
0
8
5
3
7
2
4
2
8
4
2
3
1
0
6
5
9
9
0
6
9
3
0
2
5
5
1
4
1
7
7
7
9
6
5
8
2
9
7
3
5
2
0
4
7
2
8
5
9
1
8
3
5
1
6
3
8
3
9
7
9
4
4
1
4
8
1
2
7
4
6
0
1
3
7
9
2
8
5
9
6
6
0
0
2
7
3
0
3
4
6
1
3
9
6
1
4
1
7
5
9
4
8
5
3
4
9
6
5
0
4
5
9
4
0
5
9
9
0
5
0
1
1
6
0
2
5
5
3
5
8
3
2
1


8
0
6
1
0
8
3
8
5
4
4
4
3
8
4
8
9
4
6
5
3
8
0
4
6
8
0
1
8
7
2
5
1
9
0
0
3
6
4
7
2
5
1
5
6
9
7
3
4
7
1
7
9
3
7
5
4
1
7
6
9
9
5
9
5
7
2
1
0
7
3
9
2
4
5
6
9
0
4
8
2
2
6
4
6
9
2
5
5
1
4
3
2
1
0
2
7
7
6
9
9
1
3
7
2
5
6
4
1
6
9
7
9
1
0
0
8
6
4
2
0
6
0
4
6
1
3
7
3
6
3
4
3
7
1
2
6
9
0
7
5
3
0
5
8
9
9
2
7
4
7
2
6
3
8
9
6
9
1
7
2
5
3
0
0
8
8
7
8
3
0
6
9
2
6
4
3
9
8
4
3
4
7
9
3
2
9
0
7
1
8
8
0
7
7
4
2
5
4
2
2
1
7
3
1
0
4
1
8
1
0
1
1
7
9
1
9
3
1
9
4
3
1
0
8
5
4
4
6
7
8
4
9
9
8
6
3
7
8
8
0
4
9
4
1
8
0
2
3
9
0
2
9
3
3
4
5
8
1
9
6
1
5
8
3
9
9
2
0
9
5
6
0
3
8
6
0
9
9
2
1
6
5
5
0
4
5
6
6
3
9
2
9
7
4
4
7
5
7
9
4
7
6
5
0
5
0
4
4
8
2
2
4
0
2
3
7
8
3
4
0
5
1
1
4
3
6
4
8
5
5
3
1
0
8
3
8
7
6
6
4
4
5
5
4
8
9
2
8
1
5
2
7
1
4
9
1
8
0
1
2
9
4
3
5
0
6
6
2
7
6
5
3
4
1
7
0
3
3
9
8
9
2
3
3
7
5
0
3
6
3
2
6
0
2
8
0
5
2
5
7
0
4
2
1
4
0
8
0
2
6
0
5
3
6
4
2
2
8
0
0
3
9
3
4
5
6
6
2
7
7
1
3
3
2
1
8
5
5
3
7
4
3
4
7
5
0
7
0
3
1
2
4
2
6
9
8
9
1
5
1
6
7
6
1
1
0
1
7
3
4
7
8
7
6
3
0
0
0
7
8
7
5
6
4
3
5
3
9
8
2
1
7
6
1
2
4
8
6
8


7
9
1
8
9
3
2
4
0
6
0
5
1
6
4
8
6
1
7
1
3
5
0
1
7
0
3
1
8
7
0
8
6
8
6
7
8
0
5
0
4
3
1
0
9
3
4
7
2
7
2
0
3
1
7
9
2
2
1
5
9
9
8
5
4
7
3
5
0
7
7
5
2
3
4
5
4
6
9
3
4
6
8
2
5
1
0
0
6
9
8
0
7
6
9
1
9
5
1
3
8
5
8
5
1
1
1
5
1
8
4
1
2
7
3
6
5
2
8
6
9
7
0
7
2
2
3
8
3
0
6
6
4
9
4
8
1
1
9
4
5
5
8
9
3
1
1
2
9
8
4
6
8
1
6
2
3
0
4
6
7
7
3
7
7
3
3
6
5
5
3
8
2
3
8
2
6
9
9
8
5
3
9
5
6
2
2
1
5
6
3
0
1
8
4
2
3
3
9
7
5
4
6
1
7
4
2
6
8
8
4
8
6
5
8
7
2
2
0
3
8
5
8
2
6
7
1
7
9
1
4
9
8
0
5
6
5
1
1
7
1
8
1
5
9
7
2
6
8
5
3
6
2
2
3
0
9
8
3
4
7
2
2
5
3
7
7
6
8
1
9
5
4
5
9
5
3
2
8
8
1
9
4
5
9
4
5
7
9
3
4
3
5
8
4
2
7
9
2
5
3
5
1
7
6
2
8
8
6
0
9
6
7
6
9
5
4
7
0
3
6
1
3
4
8
9
0
1
6
3
2
5
2
5
7
0
5
8
9
4
9
7
9
6
6
1
1
2
5
1
7
6
5
4
1
3
3
8
8
6
4
4
0
1
6
3
5
1
8
5
9
0
8
8
3
7
4
6
6
8
8
6
6
2
2
5
9
4
3
2
6
7
5
0
6
6
8
0
3
0
9
3
1
4
1
7
3
5
8
6
4
5
3
2
7
3
8
1
3
6
3
2
8
9
3
2
6
9
8
8
3
3
7
0
9
5
7
3
5
3
1
3
8
8
4
0
0
9
3
1
6
1
4
0
3
9
1
1
1
5
2
0
6
4
7
8
6
3
7
8
2
5
0
9
0
9
6
0
0
6
2
9
9
6
2
6
0
2
8
6
1
3
4
5
1
6
1
7
3
6


KeyboardInterrupt: 

In [10]:
# Build a second, untrained model with the same class count and neck setting
# as `model`, using the CovaMResnet56 variant.
# NOTE(review): the effect of `with_cova=True` is defined in CovaMNet, which is
# not visible here — confirm before relying on it.
model1 = CovaMResnet56(class_num=dataset[-1], neck='no', with_cova=True)

In [20]:
# Constant tensor written into the classifier weight; its shape (10, 256) is
# expected to match `ce_classifier.weight`.
a = torch.ones((10, 256))

# state_dict() tensors share storage with the live parameters, so an in-place
# copy_ here overwrites the model's actual weights.
state = model1.state_dict()
for key in state.keys():
    print(key)
    if key == 'ce_classifier.weight':
        t = state[key]
        print(t)
        t.copy_(a)
    if key == 'ce_classifier.bias':
        t = state[key]
        t.copy_(torch.zeros(t.shape))

# Freeze the classifier for real. The previous code set a misspelled
# `require_grad` attribute on a detached state-dict tensor and a plain
# `requires_grad` attribute on the module itself; neither stops gradients.
# The flag has to be set on the module's parameters.
for param in model1.ce_classifier.parameters():
    param.requires_grad = False

# Show the classifier weight again to confirm the overwrite took effect.
for key in model1.state_dict().keys():
    if key == 'ce_classifier.weight':
        t = model1.state_dict()[key]
        print(t)

base.conv1.weight
base.bn1.weight
base.bn1.bias
base.bn1.running_mean
base.bn1.running_var
base.bn1.num_batches_tracked
base.layer1.0.conv1.weight
base.layer1.0.bn1.weight
base.layer1.0.bn1.bias
base.layer1.0.bn1.running_mean
base.layer1.0.bn1.running_var
base.layer1.0.bn1.num_batches_tracked
base.layer1.0.conv2.weight
base.layer1.0.bn2.weight
base.layer1.0.bn2.bias
base.layer1.0.bn2.running_mean
base.layer1.0.bn2.running_var
base.layer1.0.bn2.num_batches_tracked
base.layer1.0.conv3.weight
base.layer1.0.bn3.weight
base.layer1.0.bn3.bias
base.layer1.0.bn3.running_mean
base.layer1.0.bn3.running_var
base.layer1.0.bn3.num_batches_tracked
base.layer1.0.downsample.0.weight
base.layer1.0.downsample.1.weight
base.layer1.0.downsample.1.bias
base.layer1.0.downsample.1.running_mean
base.layer1.0.downsample.1.running_var
base.layer1.0.downsample.1.num_batches_tracked
base.layer1.1.conv1.weight
base.layer1.1.bn1.weight
base.layer1.1.bn1.bias
base.layer1.1.bn1.running_mean
base.layer1.1.bn1.running_

In [21]:
# Same evaluation as the earlier cell, repeated: score the trained client model
# on dataset[3] and report sample count, accuracy and mean loss.
# NOTE(review): this duplicates an earlier cell verbatim; consider a single
# helper. Raises ZeroDivisionError if the loader yielded no samples.
metrics = test(client.model, dataset[3], device)
test_correct = metrics['test_correct']/metrics['test_total']
test_loss = metrics['test_loss']/metrics['test_total']
print(metrics['test_total'])
print(f'test_correct: {test_correct}; test_loss: {test_loss}')

10000
test_correct: 0.863; test_loss: 0.4603449203073978
