In [5]:
# this is the total pipeline for the project
# This file is for trainning
# Run this on the server, or as we called offline. 

# with gated, no generator
from Dataloaders.dataloader_cifar10 import Dataloader_cifar10
import argparse
import torch
import os
import sys
import numpy as np
import time
from Models import mobilenetv2
from Utils import utils
from Models import gatedmodel
from Models import generator
from tqdm import tqdm

def main(args):
    """
    # initial using mobilenetV2, and cifar10
    # we need a if statement here to decide which model and dataset to use
    # random_seed = 2024

    # for training, it is for training the generator 
    # recall the graph, when we cut more features, the performance should be worse.
    
    # get the training loader, freeze the model. Where is the partitioning point? 
    """

    # 1. get the train, test and val datasets, and labels.
    if args.dataset == 'cifar-10':
        # return train, test, val, labels, these are all dataloaders
        _, test, classes = Dataloader_cifar10(train_batch=128, test_batch=args.tb, seed=2024)
    elif args.dataset == 'cifar-100':
        pass
    
    # 2. transfer the dataset to fit the model, for the training, client and server model are all on the server
    if args.model == 'mobilenetV2':
        client, server = mobilenetv2.stupid_model_splitter(weight_path='./Weights/cifar-10/model/MobileNetV2.pth')
    elif args.model == 'resnet':
        pass

    # 3. get the gating, gating here decides how many channels are transferred to the server
    # simple version: a binary tree, complex version: model
    # get a ranker to rank the channels, and get the top k channels
    ranker = utils.ranker_entropy # input: embs, percentage, output: s_emb, s_ind

    # 4. get the gated model, we have 3 models here
    gate_name = args.gatename
    gated_rates = [0.25, 0.5, 0.75]
    channel2ind = {8:0, 16:1, 24:2, 32:3}
    input_channel = 32
    gateds = []
    for i in range(3):
        gateds.append(gatedmodel.model_list[gate_name](
            input_size=int(input_channel*gated_rates[i]),
            width=32,
            height=32,
            output_size=1)) #  input_size, weight, height, output_size=10
        # s_time = utils.get_latest_weights(args.dataset, gate_name, 'gate')
        # s_time = ('_').join(['0']*6)+'.pth'
        s_time = '2024_06_14_17_46_36'
        gateds[i].load_state_dict(torch.load('./Weights/'+args.dataset+'/gate/'+s_time+'/'+gate_name+'_'+str(i)+'_'+s_time+'.pth'))

    # 5. get the generator
    generators = []
    for i in range(3):
        generators.append(generator.Generator(
            inputsize=int(input_channel*gated_rates[i]), 
            hiddensize=32, 
            outputsize=32)) # inputsize, hiddensize, outputsize
    """
    # 6. get the server 
    # server = some_model_function()
    # server model is got above

    # pipline data -> dataloader -> client -> gating -> reducer -> generator -> server
    
    # cuda may not have enough space for putting all the models
    # client = client.cuda()
    # server = server.cuda()
    # for i in range(3):
    #     gateds[i] = gateds[i].cuda()
    #     generators[i] = generators[i].cuda() """  

    # set them to eval
    client.eval()
    server.eval()
    for i in range(3):
        gateds[i].eval()
        generators[i].eval()

    # load the test data set the test
    with torch.no_grad():
        correct = 0
        total = 0
        t_conf = 0 # total confidence
        t_zero = 0 # total number of 0
        c_gate = [0]*4
        f_gate = [0]*4
        client = client.cuda()
        server = server.cuda()
        for i, data in tqdm(enumerate(test)):
            images, labels = data
            images, labels = images.cuda(), labels.cuda()
            # first, run the client model on the iot
            out = client(images)
            # calcualte the total number of 0
            # t_zero += torch.sum(out == 0).item()/32/32/32
            # second, run the ranker
            ex_flag = args.gated
            ex_gate = 0
            # print('The number of 0 in the tensor is: ', torch.sum(out == 0).item())
            # if the exit_flag is a True
            if type(ex_flag) == bool:
                while ex_flag:
                    # get the emb and ind from the ranker
                    out = out.cpu()
                    s_emb, s_ind = ranker(out, gated_rates[ex_gate])
                    s_emb = s_emb.cuda()
                    # give it to the gated model
                    cur_gated = gateds[ex_gate].cuda() # b,e',h,w -> b,1 (softmax value)
                    g_conf = torch.mean(cur_gated(s_emb)).item() # b, n
                    # print('The confidence is: ', g_conf)
                    l_thres = args.threshold * np.exp(ex_flag*1/3)
                    if g_conf > l_thres:
                        ex_flag = not ex_flag
                    else:
                        ex_gate += 1

                    if ex_gate == 3:
                        break
                # check the exit flag and send
                if ex_flag != args.gated:
                    # print('The choosen gate is: ', ex_gate)
                    out = s_emb

            # if the exit_flag is a number
            elif type(ex_flag) == int:
                out = out.cpu()
                s_emb, s_ind = ranker(out, gated_rates[ex_flag])
                s_emb = s_emb.cuda()
                cur_gated = gateds[ex_flag].cuda()
                g_conf = torch.mean(cur_gated(s_emb)).item() # b, n
                # print('The confidence is: ', g_conf.item())
                out = s_emb # b, c, h, w
                # print the number of 0 in the tensor)
                # print('The number of 0 in the tensor is: ', torch.sum(out == 0).item())
                # return 0

            # a sender here, but on server, we don't have it.
            
            # a receiver here, but on server, we don't have it.
            t_conf += g_conf
            c_gate[ex_gate] += 1
            f_gate[ex_gate] += g_conf
            # get the generator
            rec_size = out.size(1)
            # print('The size of the tensor is: ', rec_size)
            if ex_flag != args.gated or type(args.gated) == int: # int !=
                rec_ind = s_ind
            rec_size2ind = channel2ind[rec_size]
            # if we don't get all features, we need to use the generator
            if args.generator:
                if rec_size2ind != 3:
                    cur_gen = generators[rec_size2ind]
                    # load weights
                    cur_gen.load_state_dict(torch.load('./Weights/cifar-10/generator/generator_'+str(rec_size2ind)+'.pth'))
                    cur_gen = cur_gen.cuda()
                    out = cur_gen(out)

            # skip the generator, create a tensor with all zeros
            else:
                if ex_flag != args.gated or type(args.gated) == int:
                    n_out = torch.zeros(out.size(0), 32, 32, 32).cuda()
                    n_out[:, rec_ind, :, :] = out
                    out = n_out
            # run the server model
            out = out.cuda()
            out = server(out)
            _, predicted = torch.max(out.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            # print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))

        print('The average confidence is: ', 100 * t_conf/total)
        print('The average confidence for each gate is: %.4f, %.4f, %.4f, %.4f' % (f_gate[0]/max(c_gate[0], 1), \
                                                                             f_gate[1]/max(c_gate[1], 1), f_gate[2]/max(c_gate[2], 1), f_gate[3]/max(c_gate[3], 1)))
        print('The average accuracy is: ', 100 * correct / total)
        # print('The average number of 0 is: ', t_zero/total)   
        print('Gate chosen: %d, %d, %d, %d' % (c_gate[0], c_gate[1], c_gate[2], c_gate[3]))

# if __name__ == '__main__':
#     print('enter')
#     parser = argparse.ArgumentParser()
#     # we need the name of model, the name of dataset
#     parser.add_argument('--dataset', type=str, default='cifar10', help='name of dataset')
#     # parser.add_argument('--iot_model', type=str, default='mobilenetV2', help='name of the model on the iot')
#     parser.add_argument('--reducer', type=str, default='entrophy', help='name of the reducer')
#     parser.add_argument('--client', type=str, default='LTE', help='name of the network condition on the client side')
#     parser.add_argument('--server', type=str, default='LTE', help='name of the network condition on the server side')
#     parser.add_argument('--generator', type=str, default='None', help='name of the generator')
#     # parser.add_argument('--server', type=str, default='mobilenetV2', help='name of the model on the server, should be the same as it on the iot')
#     parser.add_argument('--device', type=str, default='home', help='run on which device, home, tintin, rpi, pico, jetson?')
#     parser.add_argument('--model', type=str, default='mobilenetV2', help='name of the model')
#     args = parser.parse_args()
#     main(args)

class custom_args:
    def __init__(self):
        self.dataset = 'cifar-10'
        self.model = 'mobilenetV2'
        self.gated = True
        self.gatename = 'GateMLP'
        self.ranker = 'entropy'
        self.generator = False
        self.threshold = 0.6
        self.tb = 100

    def __str__(self):
        return 'dataset: '+self.dataset+', model: '+self.model+', gated: '+str(self.gated)+', ranker: '+self.ranker+', generator: '+str(self.generator) \
            + ', gatename: '+self.gatename + ', threshold: '+str(self.threshold) + ', test_batch: '+str(self.tb)

args = custom_args()
print(args)
main(args)

dataset: cifar-10, model: mobilenetV2, gated: True, ranker: entropy, generator: False, gatename: GateMLP, threshold: 0.6, test_batch: 100
Files already downloaded and verified
Files already downloaded and verified


100it [00:09, 10.42it/s]

The average confidence is:  0.8452839583158493
The average confidence for each gate is: 0.8392, 0.8452, 0.8515, 0.8291
The average accuracy is:  91.96
Gate chosen: 1, 36, 46, 17





In [67]:
# load the gate weight
weightpath = '/home/tonypeng/Workspace1/adaptfilter/Adaptfilter/Weights/cifar-10/gate/2024_06_14_17_46_36/GateMLP_0_2024_06_14_17_46_36.pth'
import torch
torch.load(weightpath)

OrderedDict([('linear1.weight',
              tensor([[-0.0119, -0.0055, -0.0026,  ..., -0.0123, -0.0007, -0.0031],
                      [ 0.0092, -0.0060, -0.0047,  ..., -0.0119, -0.0115, -0.0130],
                      [-0.0097, -0.0155, -0.0152,  ..., -0.0007, -0.0081, -0.0029],
                      ...,
                      [-0.0099, -0.0155, -0.0089,  ...,  0.0029,  0.0028, -0.0094],
                      [-0.0064, -0.0034,  0.0038,  ...,  0.0067,  0.0045, -0.0117],
                      [-0.0043, -0.0164, -0.0147,  ..., -0.0147,  0.0012, -0.0130]],
                     device='cuda:0')),
             ('linear1.bias',
              tensor([ 4.8422e-03,  4.0539e-03, -1.3591e-02,  9.9373e-04, -1.6840e-02,
                      -1.2041e-03, -1.3823e-02,  5.2071e-03,  2.7350e-03,  1.3404e-03,
                      -1.0567e-02, -6.7861e-04, -5.1796e-05,  4.4697e-03, -5.3810e-03,
                      -1.1258e-02, -2.0388e-03, -1.5965e-02,  5.2826e-03,  1.7817e-03,
                  