In [2]:

#!/usr/bin/env python
import copy
import torch
import argparse
import os
import time
import warnings
import numpy as np
import torchvision
import logging

from flcore.servers.serveravg import FedAvg
from flcore.servers.serverpFedMe import pFedMe
from flcore.servers.serverperavg import PerAvg
from flcore.servers.serverprox import FedProx
from flcore.servers.serverfomo import FedFomo
from flcore.servers.serveramp import FedAMP
from flcore.servers.servermtl import FedMTL
from flcore.servers.serverlocal import Local
from flcore.servers.serverper import FedPer
from flcore.servers.serverapfl import APFL
from flcore.servers.serverditto import Ditto
from flcore.servers.serverrep import FedRep
from flcore.servers.serverphp import FedPHP
from flcore.servers.serverbn import FedBN
from flcore.servers.serverrod import FedROD
from flcore.servers.serverproto import FedProto
from flcore.servers.serverdyn import FedDyn
from flcore.servers.servermoon import MOON
from flcore.servers.serverbabu import FedBABU
from flcore.servers.serverapple import APPLE
from flcore.servers.serverfedtrans import FedTrans 

from flcore.trainmodel.models import *

from flcore.trainmodel.bilstm import BiLSTM_TextClassification
# from flcore.trainmodel.resnet import resnet18 as resnet
from flcore.trainmodel.alexnet import alexnet
from flcore.trainmodel.mobilenet_v2 import mobilenet_v2
from utils.result_utils import average_data
from utils.mem_utils import MemReporter

logger = logging.getLogger()
logger.setLevel(logging.ERROR)

warnings.simplefilter("ignore")
torch.manual_seed(0)

# hyper-params for Text tasks
vocab_size = 98635
max_len=200
hidden_dim=32

def run(args):

    time_list = []
    reporter = MemReporter()
    model_str = args.model

    for i in range(args.prev, args.times):
        print(f"\n============= Running time: {i}th =============")
        print("Creating server and clients ...")
        start = time.time()

        # Generate args.model
        if model_str == "mlr":
            if args.dataset == "mnist" or args.dataset == "fmnist":
                args.model = Mclr_Logistic(1*28*28, num_classes=args.num_classes).to(args.device)
            elif args.dataset == "Cifar10" or args.dataset == "Cifar100":
                args.model = Mclr_Logistic(3*32*32, num_classes=args.num_classes).to(args.device)
            else:
                args.model = Mclr_Logistic(60, num_classes=args.num_classes).to(args.device)

        elif model_str == "cnn":
            if args.dataset[:5] == "mnist" or args.dataset == "fmnist":
                args.model = FedAvgCNN(in_features=1, num_classes=args.num_classes, dim=1024).to(args.device)
            elif args.dataset == "omniglot":
                args.model = FedAvgCNN(in_features=1, num_classes=args.num_classes, dim=33856).to(args.device)
            elif args.dataset[:5] == "Cifar":
                args.model = FedAvgCNN(in_features=3, num_classes=args.num_classes, dim=1600).to(args.device)
                # args.model = CifarNet(num_classes=args.num_classes).to(args.device)
            elif args.dataset == "Digit5":
                args.model = Digit5CNN().to(args.device)
            else:
                args.model = FedAvgCNN(in_features=3, num_classes=args.num_classes, dim=10816).to(args.device)

        elif model_str == "dnn": # non-convex
            if args.dataset == "mnist" or args.dataset == "fmnist":
                args.model = DNN(1*28*28, 100, num_classes=args.num_classes).to(args.device)
            elif args.dataset == "Cifar10" or args.dataset == "Cifar100":
                args.model = DNN(3*32*32, 100, num_classes=args.num_classes).to(args.device)
            else:
                args.model = DNN(60, 20, num_classes=args.num_classes).to(args.device)
        
        elif model_str == "resnet":
            args.model = torchvision.models.resnet18(pretrained=False, num_classes=args.num_classes).to(args.device)
            
            # args.model = torchvision.models.resnet18(pretrained=True).to(args.device)
            # feature_dim = list(args.model.fc.parameters())[0].shape[1]
            # args.model.fc = nn.Linear(feature_dim, args.num_classes).to(args.device)
            
            # args.model = resnet18(num_classes=args.num_classes, has_bn=True, bn_block_num=4).to(args.device)

        elif model_str == "alexnet":
            args.model = alexnet(pretrained=False, num_classes=args.num_classes).to(args.device)
            
            # args.model = alexnet(pretrained=True).to(args.device)
            # feature_dim = list(args.model.fc.parameters())[0].shape[1]
            # args.model.fc = nn.Linear(feature_dim, args.num_classes).to(args.device)
            
        elif model_str == "googlenet":
            args.model = torchvision.models.googlenet(pretrained=False, aux_logits=False, num_classes=args.num_classes).to(args.device)
            
            # args.model = torchvision.models.googlenet(pretrained=True, aux_logits=False).to(args.device)
            # feature_dim = list(args.model.fc.parameters())[0].shape[1]
            # args.model.fc = nn.Linear(feature_dim, args.num_classes).to(args.device)

        elif model_str == "mobilenet_v2":
            args.model = mobilenet_v2(pretrained=False, num_classes=args.num_classes).to(args.device)
            
            # args.model = mobilenet_v2(pretrained=True).to(args.device)
            # feature_dim = list(args.model.fc.parameters())[0].shape[1]
            # args.model.fc = nn.Linear(feature_dim, args.num_classes).to(args.device)
            
        elif model_str == "lstm":
            args.model = LSTMNet(hidden_dim=hidden_dim, vocab_size=vocab_size, num_classes=args.num_classes).to(args.device)

        elif model_str == "bilstm":
            args.model = BiLSTM_TextClassification(input_size=vocab_size, hidden_size=hidden_dim, output_size=args.num_classes, 
                        num_layers=1, embedding_dropout=0, lstm_dropout=0, attention_dropout=0, 
                        embedding_length=hidden_dim).to(args.device)

        elif model_str == "fastText":
            args.model = fastText(hidden_dim=hidden_dim, vocab_size=vocab_size, num_classes=args.num_classes).to(args.device)

        elif model_str == "TextCNN":
            args.model = TextCNN(hidden_dim=hidden_dim, max_len=max_len, vocab_size=vocab_size, 
                            num_classes=args.num_classes).to(args.device)

        elif model_str == "Transformer":
            args.model = TransformerModel(ntoken=vocab_size, d_model=hidden_dim, nhead=2, d_hid=hidden_dim, nlayers=2, 
                            num_classes=args.num_classes).to(args.device)
        
        elif model_str == "AmazonMLP":
            args.model = AmazonMLP().to(args.device)

        else:
            raise NotImplementedError

        print(args.model)

        # select algorithm
        if args.algorithm == "FedAvg":
            server = FedAvg(args, i)

        elif args.algorithm == "Local":
            server = Local(args, i)

        elif args.algorithm == "FedMTL":
            server = FedMTL(args, i)

        elif args.algorithm == "PerAvg":
            server = PerAvg(args, i)

        elif args.algorithm == "pFedMe":
            server = pFedMe(args, i)

        elif args.algorithm == "FedProx":
            server = FedProx(args, i)

        elif args.algorithm == "FedFomo":
            server = FedFomo(args, i)

        elif args.algorithm == "FedAMP":
            server = FedAMP(args, i)

        elif args.algorithm == "APFL":
            server = APFL(args, i)

        elif args.algorithm == "FedPer":
            args.head = copy.deepcopy(args.model.fc)
            args.model.fc = nn.Identity()
            args.model = LocalModel(args.model, args.head)
            server = FedPer(args, i)

        elif args.algorithm == "Ditto":
            server = Ditto(args, i)

        elif args.algorithm == "FedRep":
            args.head = copy.deepcopy(args.model.fc)
            args.model.fc = nn.Identity()
            args.model = LocalModel(args.model, args.head)
            server = FedRep(args, i)

        elif args.algorithm == "FedPHP":
            args.head = copy.deepcopy(args.model.fc)
            args.model.fc = nn.Identity()
            args.model = LocalModel(args.model, args.head)
            server = FedPHP(args, i)

        elif args.algorithm == "FedBN":
            server = FedBN(args, i)

        elif args.algorithm == "FedROD":
            args.head = copy.deepcopy(args.model.fc)
            args.model.fc = nn.Identity()
            args.model = LocalModel(args.model, args.head)
            server = FedROD(args, i)

        elif args.algorithm == "FedProto":
            args.head = copy.deepcopy(args.model.fc)
            args.model.fc = nn.Identity()
            args.model = LocalModel(args.model, args.head)
            server = FedProto(args, i)

        elif args.algorithm == "FedDyn":
            server = FedDyn(args, i)

        elif args.algorithm == "MOON":
            args.head = copy.deepcopy(args.model.fc)
            args.model.fc = nn.Identity()
            args.model = LocalModel(args.model, args.head)
            server = MOON(args, i)

        elif args.algorithm == "FedBABU":
            args.head = copy.deepcopy(args.model.fc)
            args.model.fc = nn.Identity()
            args.model = LocalModel(args.model, args.head)
            server = FedBABU(args, i)

        elif args.algorithm == "APPLE":
            server = APPLE(args, i)
            
        elif args.algorithm == "FedTrans":
            args.head = copy.deepcopy(args.model.fc)
            args.model.fc = nn.Identity()
            args.model = LocalModel(args.model, args.head)
            server = FedTrans(args, i)
            
        else:
            raise NotImplementedError
    
    return server

if __name__ == "__main__":
    total_start = time.time()

    parser = argparse.ArgumentParser()
    # general
    parser.add_argument('-go', "--goal", type=str, default="test", 
                        help="The goal for this experiment")
    parser.add_argument('-dev', "--device", type=str, default="cuda",
                        choices=["cpu", "cuda"])
    parser.add_argument('-did', "--device_id", type=str, default="0")
    parser.add_argument('-data', "--dataset", type=str, default="mnist")
    parser.add_argument('-nb', "--num_classes", type=int, default=10)
    parser.add_argument('-m', "--model", type=str, default="cnn")
    parser.add_argument('-p', "--head", type=str, default="cnn")
    parser.add_argument('-lbs', "--batch_size", type=int, default=10)
    parser.add_argument('-lr', "--local_learning_rate", type=float, default=0.005,
                        help="Local learning rate")
    parser.add_argument('-gr', "--global_rounds", type=int, default=1000)
    parser.add_argument('-ls', "--local_steps", type=int, default=1)
    parser.add_argument('-algo', "--algorithm", type=str, default="FedAvg")
    parser.add_argument('-jr', "--join_ratio", type=float, default=1.0,
                        help="Ratio of clients per round")
    parser.add_argument('-rjr', "--random_join_ratio", type=bool, default=False,
                        help="Random ratio of clients per round")
    parser.add_argument('-nc', "--num_clients", type=int, default=2,
                        help="Total number of clients")
    parser.add_argument('-pv', "--prev", type=int, default=0,
                        help="Previous Running times")
    parser.add_argument('-t', "--times", type=int, default=1,
                        help="Running times")
    parser.add_argument('-eg', "--eval_gap", type=int, default=1,
                        help="Rounds gap for evaluation")
    parser.add_argument('-dp', "--privacy", type=bool, default=False,
                        help="differential privacy")
    parser.add_argument('-dps', "--dp_sigma", type=float, default=0.0)
    parser.add_argument('-sfn', "--save_folder_name", type=str, default='models')
    # practical
    parser.add_argument('-cdr', "--client_drop_rate", type=float, default=0.0,
                        help="Rate for clients that train but drop out")
    parser.add_argument('-tsr', "--train_slow_rate", type=float, default=0.0,
                        help="The rate for slow clients when training locally")
    parser.add_argument('-ssr', "--send_slow_rate", type=float, default=0.0,
                        help="The rate for slow clients when sending global model")
    parser.add_argument('-ts', "--time_select", type=bool, default=False,
                        help="Whether to group and select clients at each round according to time cost")
    parser.add_argument('-tth', "--time_threthold", type=float, default=10000,
                        help="The threthold for droping slow clients")
    # pFedMe / PerAvg / FedProx / FedAMP / FedPHP
    parser.add_argument('-bt', "--beta", type=float, default=0.0,
                        help="Average moving parameter for pFedMe, Second learning rate of Per-FedAvg, \
                        or L1 regularization weight of FedTransfer")
    parser.add_argument('-lam', "--lamda", type=float, default=1.0,
                        help="Regularization weight for pFedMe and FedAMP")
    parser.add_argument('-mu', "--mu", type=float, default=0,
                        help="Proximal rate for FedProx")
    parser.add_argument('-K', "--K", type=int, default=5,
                        help="Number of personalized training steps for pFedMe")
    parser.add_argument('-lrp', "--p_learning_rate", type=float, default=0.01,
                        help="personalized learning rate to caculate theta aproximately using K steps")
    # FedFomo
    parser.add_argument('-M', "--M", type=int, default=5,
                        help="Server only sends M client models to one client at each round")
    # FedMTL
    parser.add_argument('-itk', "--itk", type=int, default=4000,
                        help="The iterations for solving quadratic subproblems")
    # FedAMP
    parser.add_argument('-alk', "--alphaK", type=float, default=1.0, 
                        help="lambda/sqrt(GLOABL-ITRATION) according to the paper")
    parser.add_argument('-sg', "--sigma", type=float, default=1.0)
    # APFL
    parser.add_argument('-al', "--alpha", type=float, default=1.0)
    # Ditto / FedRep
    parser.add_argument('-pls', "--plocal_steps", type=int, default=1)
    # MOON
    parser.add_argument('-ta', "--tau", type=float, default=1.0)
    # FedBABU
    parser.add_argument('-fts', "--fine_tuning_steps", type=int, default=1)
    # APPLE
    parser.add_argument('-dlr', "--dr_learning_rate", type=float, default=0.0)
    parser.add_argument('-L', "--L", type=float, default=1.0)
    #FedTrans
    parser.add_argument('-ere', "--every_recluster_eps", type=int, default=5)
    parser.add_argument('-ed', "--emb_dim", type=int, default=128)
    parser.add_argument('-alr', "--attn_learning_rate", type=float, default=0.005)
    parser.add_argument('-ncl', "--num_cluster", type=int, default=10)

    
    args = parser.parse_args(args=["-data","mnist", "-m", "cnn", -algo FedTrans -gr 2500 -did 0 -go cnn -nc 2")
    os.environ["CUDA_VISIBLE_DEVICES"] = args.device_id

    if args.device == "cuda" and not torch.cuda.is_available():
        print("\ncuda is not avaiable.\n")
        args.device = "cpu"

    print("=" * 50)

    print("Algorithm: {}".format(args.algorithm))
    print("Local batch size: {}".format(args.batch_size))
    print("Local steps: {}".format(args.local_steps))
    print("Local learing rate: {}".format(args.local_learning_rate))
    print("Total number of clients: {}".format(args.num_clients))
    print("Clients join in each round: {}".format(args.join_ratio))
    print("Client drop rate: {}".format(args.client_drop_rate))
    print("Time select: {}".format(args.time_select))
    print("Time threthold: {}".format(args.time_threthold))
    print("Global rounds: {}".format(args.global_rounds))
    print("Running times: {}".format(args.times))
    print("Dataset: {}".format(args.dataset))
    print("Local model: {}".format(args.model))
    print("Using device: {}".format(args.device))

    if args.device == "cuda":
        print("Cuda device id: {}".format(os.environ["CUDA_VISIBLE_DEVICES"]))
    print("=" * 50)


    # if args.dataset == "mnist" or args.dataset == "fmnist":
    #     generate_mnist('../dataset/mnist/', args.num_clients, 10, args.niid)
    # elif args.dataset == "Cifar10" or args.dataset == "Cifar100":
    #     generate_cifar10('../dataset/Cifar10/', args.num_clients, 10, args.niid)
    # else:
    #     generate_synthetic('../dataset/synthetic/', args.num_clients, 10, args.niid)

    # with torch.profiler.profile(
    #     activities=[
    #         torch.profiler.ProfilerActivity.CPU,
    #         torch.profiler.ProfilerActivity.CUDA],
    #     profile_memory=True, 
    #     on_trace_ready=torch.profiler.tensorboard_trace_handler('./log')
    #     ) as prof:
    # with torch.autograd.profiler.profile(profile_memory=True) as prof:
    server = run(args)

    
    # print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))
    # print(f"\nTotal time cost: {round(time.time()-total_start, 2)}s.")


usage: ipykernel_launcher.py [-h] [-go GOAL] [-dev {cpu,cuda}]
                             [-did DEVICE_ID] [-data DATASET]
                             [-nb NUM_CLASSES] [-m MODEL] [-p HEAD]
                             [-lbs BATCH_SIZE] [-lr LOCAL_LEARNING_RATE]
                             [-gr GLOBAL_ROUNDS] [-ls LOCAL_STEPS]
                             [-algo ALGORITHM] [-jr JOIN_RATIO]
                             [-rjr RANDOM_JOIN_RATIO] [-nc NUM_CLIENTS]
                             [-pv PREV] [-t TIMES] [-eg EVAL_GAP]
                             [-dp PRIVACY] [-dps DP_SIGMA]
                             [-sfn SAVE_FOLDER_NAME] [-cdr CLIENT_DROP_RATE]
                             [-tsr TRAIN_SLOW_RATE] [-ssr SEND_SLOW_RATE]
                             [-ts TIME_SELECT] [-tth TIME_THRETHOLD]
                             [-bt BETA] [-lam LAMDA] [-mu MU] [-K K]
                             [-lrp P_LEARNING_RATE] [-M M] [-itk ITK]
                             [-alk ALPHAK] [-sg

SystemExit: 2

In [16]:
import torch
res = torch.load("cel.pt")
emb_list, weights = res
print(emb_list, weights)

[tensor([[-1.9877e-02,  2.7962e-02,  2.6841e-02,  9.4542e-03,  2.0430e-02,
          1.8452e-02,  2.3748e-02, -1.5183e-03, -6.0082e-03, -1.6457e-02,
          2.3891e-02,  1.3427e-02,  1.0454e-02, -2.1478e-02, -9.7432e-03,
         -1.8438e-02,  3.3368e-02,  3.0730e-02,  3.9423e-02, -5.2264e-03,
          1.8984e-02,  2.5925e-02,  1.6223e-02,  3.3737e-02,  9.2517e-04,
          3.0048e-03,  1.4591e-02, -8.9317e-04,  1.3517e-02,  1.3932e-02,
         -1.6047e-02,  2.6537e-02, -1.7544e-02, -8.0805e-03,  9.3743e-03,
          2.2606e-02,  5.7925e-03, -1.6918e-02,  2.9809e-02,  2.0877e-02,
         -2.3509e-02, -2.1547e-02, -2.0066e-02,  2.8955e-02, -6.8141e-03,
          1.4123e-02, -2.0192e-02, -2.7565e-02, -7.6179e-03,  1.0307e-02,
          1.6761e-02,  8.9698e-03, -2.5151e-02, -1.5949e-02,  5.4604e-02,
         -3.8716e-03,  9.1014e-03, -7.1127e-03, -7.1921e-03,  1.5355e-02,
          1.0121e-02,  2.2900e-02, -5.8777e-03,  2.0517e-02, -2.2348e-02,
          5.3340e-04,  1.5209e-02, -1

In [17]:
print(weights.size())

torch.Size([10, 10])


In [18]:

x = torch.cat(emb_list, dim=0).squeeze(1)
print(x.size())

torch.Size([10, 128])


In [24]:
from flcore.servers.serverfedtrans import Attn_Model
device = "cuda:0"
x.to(device)
attn_model = Attn_Model().to(device)
weights = attn_model(x)

In [None]:
print(weights)

In [88]:
res = torch.load("res.pt")

In [89]:
iter_e, iter_w = res['inter_clusters_res']

In [90]:
print(iter_e[1].size())

x = torch.cat(iter_e, dim=0).squeeze(1)

torch.Size([1, 128])


In [91]:
print(x)

tensor([[-0.0129, -0.0057, -0.0055,  ..., -0.0217,  0.0069,  0.0132],
        [-0.0129, -0.0057, -0.0055,  ..., -0.0217,  0.0069,  0.0132],
        [-0.0129, -0.0057, -0.0055,  ..., -0.0217,  0.0069,  0.0132],
        ...,
        [-0.0129, -0.0057, -0.0055,  ..., -0.0217,  0.0069,  0.0132],
        [-0.0129, -0.0057, -0.0055,  ..., -0.0217,  0.0069,  0.0132],
        [-0.0129, -0.0057, -0.0055,  ..., -0.0217,  0.0069,  0.0132]],
       device='cuda:0', grad_fn=<SqueezeBackward1>)


In [93]:
import torch
from torch import nn

class Attn_Model_C(nn.Module):
    def __init__(self, emb_dim=128, attn_dim=128, num_heads=8):
        super(Attn_Model_C, self).__init__()
        self.emb_dim = emb_dim
        self.attn_dim = attn_dim
        self.query = nn.Linear(emb_dim, attn_dim)
        self.key = nn.Linear(emb_dim, attn_dim)
        #self.inter_LN = nn.LayerNorm(attn_dim)

        # 1-layer attention for simple verify

    def forward(self, x, models=None, prev_models=None):
        #x = self.inter_LN(x) 
        q = self.query(x)

        k = self.key(x)
        print("q:{}\n{}\nk:{}".format(q,"-"*5,k))
        scores = torch.matmul(q, k.transpose(-2, -1)) 
        #scores = torch.matmul(q, k.transpose(-2, -1)) / (self.attn_dim ** 0.2)
        print(scores)
        attention_weights = torch.softmax(scores, dim=-1)
        return attention_weights

In [94]:

attn_model_1 = Attn_Model_C().to(device)
w = attn_model(x.to(device))

print(w)

tensor([[0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
         0.1000],
        [0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
         0.1000],
        [0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
         0.1000],
        [0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
         0.1000],
        [0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
         0.1000],
        [0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
         0.1000],
        [0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
         0.1000],
        [0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
         0.1000],
        [0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
         0.1000],
        [0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
         0.1000]], device='c

In [95]:
for c in res['intra_clusters_res']:
    if c is not None:
        c_e = c[0]
    else:
        continue
    
    x = torch.cat(c_e, dim=0).squeeze(1)
    w = attn_model_1(x.to(device))
    print(w)

q:tensor([[-0.0525, -0.0566,  0.0640,  0.0268, -0.0741,  0.0886,  0.0076,  0.0588,
         -0.0354, -0.0399, -0.0119,  0.0477,  0.0801, -0.0657,  0.0346,  0.0919,
         -0.0823, -0.0304, -0.0241,  0.0121, -0.0761,  0.0020,  0.0920,  0.0116,
          0.0271, -0.0015,  0.0399, -0.0113, -0.0653,  0.0036,  0.0314, -0.0638,
          0.0260,  0.0091, -0.0882,  0.0576,  0.0774, -0.0822,  0.0228, -0.0141,
          0.0688, -0.0367, -0.0510, -0.0217,  0.0721,  0.0554, -0.0007,  0.0497,
         -0.0603,  0.0613,  0.0568,  0.0223, -0.0654, -0.0129, -0.0203, -0.0176,
          0.0715, -0.0241, -0.0283,  0.0038, -0.0089, -0.0181, -0.0246,  0.0115,
          0.0779,  0.0044,  0.0549,  0.0131,  0.0880, -0.0409, -0.0394,  0.0348,
          0.0149,  0.0231, -0.0773,  0.0203, -0.0754,  0.0344, -0.0344, -0.0673,
          0.0095, -0.0328,  0.0609,  0.0574, -0.0101,  0.0058,  0.0222,  0.0608,
          0.0445,  0.0434, -0.0725, -0.0230, -0.0703,  0.0220, -0.0766,  0.0596,
         -0.0349, -0.0171,

In [60]:

print(res['intra_clusters_res'][1])

[[tensor([[-0.0147,  0.0313,  0.0258,  0.0092,  0.0214,  0.0190,  0.0234, -0.0043,
         -0.0071, -0.0138,  0.0249,  0.0151,  0.0096, -0.0220, -0.0082, -0.0208,
          0.0327,  0.0313,  0.0388, -0.0004,  0.0164,  0.0258,  0.0185,  0.0288,
          0.0043,  0.0055,  0.0126, -0.0058,  0.0130,  0.0121, -0.0192,  0.0276,
         -0.0192, -0.0082,  0.0072,  0.0216,  0.0122, -0.0220,  0.0275,  0.0229,
         -0.0188, -0.0218, -0.0236,  0.0266, -0.0034,  0.0145, -0.0166, -0.0270,
         -0.0081,  0.0072,  0.0143,  0.0110, -0.0283, -0.0150,  0.0533, -0.0029,
          0.0148, -0.0068, -0.0048,  0.0139,  0.0098,  0.0208, -0.0031,  0.0192,
         -0.0191,  0.0043,  0.0133, -0.0158,  0.0008,  0.0148, -0.0113, -0.0197,
         -0.0134,  0.0027, -0.0089, -0.0014,  0.0095, -0.0071,  0.0130, -0.0035,
          0.0129, -0.0388,  0.0110,  0.0120, -0.0080, -0.0311,  0.0080,  0.0009,
          0.0069,  0.0059,  0.0041,  0.0078, -0.0068, -0.0019, -0.0114, -0.0383,
          0.0061,  0.0020,