In [29]:
# This is the end-to-end pipeline for the project.
# This cell only evaluates the pipeline on the test set; no training happens here.
# Run this on the server (what we call "offline").

# Configuration for this cell: no gating, no generator.
from Dataloaders.dataloader_cifar10 import Dataloader_cifar10
import argparse
import torch
import os
import sys
import numpy as np
import time
from Models import mobilenetv2
from Utils import utils
from Models import gatedmodel
from Models import generator

def main(args):
    """Evaluate the split client/server pipeline on the test set and print its accuracy.

    Per batch the data flows through: client_model -> (optional) channel
    gating with the entropy ranker -> (optional) generator, or zero-filling of
    the channels that were not selected -> server_model.

    All models and tensors are moved with ``.cuda()``, so a CUDA device is
    required.

    Args:
        args: configuration object. The attributes read here are ``dataset``
            (only 'cifar10' is implemented), ``model`` (only 'mobilenetV2' is
            implemented), ``gated`` (truthy enables channel gating) and
            ``generator`` (truthy runs a generator on a reduced feature map;
            falsy zero-fills the channels that were not selected).

    Raises:
        NotImplementedError: if ``args.dataset`` or ``args.model`` selects an
            option that has no implementation yet.
    """
    # 1. Test dataloader. The other two values returned by the helper are not
    # used during evaluation.
    if args.dataset == 'cifar10':
        _, test, _ = Dataloader_cifar10(train_batch=128, test_batch=100, seed=2024)
    else:
        # Fail here rather than on an unbound ``test`` further down.
        raise NotImplementedError('dataset not supported: ' + str(args.dataset))

    # 2. Client and server halves of the backbone; during offline evaluation
    # both halves live on the server.
    if args.model == 'mobilenetV2':
        client_model, server_model = mobilenetv2.stupid_model_splitter(
            weight_path='./Weights/cifar-10/MobileNetV2.pth')
    else:
        raise NotImplementedError('model not supported: ' + str(args.model))

    # 3. Ranker used to pick the channels that are transferred to the server.
    # input: embs, percentage, output: s_emb, s_ind
    ranker = utils.ranker_entropy

    # 4./5. One gate and one generator per channel budget (25%, 50%, 75%).
    gated_rates = [0.25, 0.5, 0.75]
    channel2ind = {8: 0, 16: 1, 24: 2, 32: 3}
    input_channel = 32
    gateds = [
        gatedmodel.GatedRegression(
            input_size=int(input_channel * rate),
            weight=32,
            height=32,
            output_size=10)
        for rate in gated_rates
    ]
    generators = [
        generator.Generator(
            inputsize=int(input_channel * rate),
            hiddensize=32,
            outputsize=32)
        for rate in gated_rates
    ]

    # Everything runs in inference mode.
    client_model.eval()
    server_model.eval()
    for gate_model, gen_model in zip(gateds, generators):
        gate_model.eval()
        gen_model.eval()

    global_threshold = 0.8  # mean batch confidence a gate must exceed to be chosen
    full_index = channel2ind[input_channel]  # index meaning "all channels present"

    # Indices of the gates / generators whose weights are already loaded.
    loaded_gates = set()
    loaded_generators = set()

    def _ready(models, loaded, index, path_prefix):
        """Return models[index] on the GPU, reading its weights from disk only on first use."""
        model = models[index]
        if index not in loaded:
            model.load_state_dict(torch.load(path_prefix + str(index) + '.pth'))
            model.cuda()
            loaded.add(index)
        return model

    with torch.no_grad():
        correct = 0
        total = 0
        client_model = client_model.cuda()
        server_model = server_model.cuda()
        for images, labels in test:
            images, labels = images.cuda(), labels.cuda()
            # First, run the client model (the part meant for the IoT device).
            out = client_model(images)

            # Second, try the gates from the smallest channel budget upwards
            # and stop at the first one that is confident enough.
            gate_exited = False
            if args.gated:
                ranker_input = out.cpu()  # the ranker is fed a CPU tensor
                for counter, rate in enumerate(gated_rates):
                    s_emb, s_ind = ranker(ranker_input, rate)
                    s_emb = s_emb.cuda()
                    cur_gated = _ready(
                        gateds, loaded_gates, counter,
                        './Weights/cifar-10/GatedRegression_')
                    g_emb = cur_gated(s_emb)  # b, n
                    # Per-sample confidence is the sigmoid of the largest
                    # output; the decision uses the mean over the batch.
                    g_conf = torch.sigmoid(torch.max(g_emb, dim=1).values).mean()
                    if g_conf > global_threshold:
                        gate_exited = True
                        break
                # If no gate was confident enough, ``out`` is left untouched so
                # the full feature map is forwarded, instead of indexing past
                # the end of ``gated_rates``.

            if gate_exited:
                out = s_emb
                rec_ind = s_ind

            # A sender/receiver pair would sit here; on the server there is none.
            rec_size2ind = channel2ind[out.size(1)]

            if args.generator:
                # The generator is only needed when some channels are missing.
                if rec_size2ind != full_index:
                    cur_gen = _ready(
                        generators, loaded_generators, rec_size2ind,
                        './Weights/cifar-10/generator_')
                    out = cur_gen(out)
            elif gate_exited:
                # No generator: place the selected channels into an all-zero
                # feature map at their original channel positions.
                n_out = torch.zeros(out.size(0), 32, 32, 32, device=out.device)
                n_out[:, rec_ind, :, :] = out
                out = n_out

            # Finally, run the server model and accumulate the accuracy.
            out = server_model(out)
            _, predicted = torch.max(out, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))

# if __name__ == '__main__':
#     print('enter')
#     parser = argparse.ArgumentParser()
#     # we need the name of model, the name of dataset
#     parser.add_argument('--dataset', type=str, default='cifar10', help='name of dataset')
#     # parser.add_argument('--iot_model', type=str, default='mobilenetV2', help='name of the model on the iot')
#     parser.add_argument('--reducer', type=str, default='entrophy', help='name of the reducer')
#     parser.add_argument('--client', type=str, default='LTE', help='name of the network condition on the client side')
#     parser.add_argument('--server', type=str, default='LTE', help='name of the network condition on the server side')
#     parser.add_argument('--generator', type=str, default='None', help='name of the generator')
#     # parser.add_argument('--server_model', type=str, default='mobilenetV2', help='name of the model on the server, should be the same as it on the iot')
#     parser.add_argument('--device', type=str, default='home', help='run on which device, home, tintin, rpi, pico, jetson?')
#     parser.add_argument('--model', type=str, default='mobilenetV2', help='name of the model')
#     args = parser.parse_args()
#     main(args)

class custom_args:
    """Hard-coded replacement for the command-line arguments passed to ``main``."""

    def __init__(self):
        # Dataset and backbone selection.
        self.dataset, self.model = 'cifar10', 'mobilenetV2'
        # Channel gating is switched off in this configuration.
        self.gated = False
        # Name of the channel ranker.
        self.ranker = 'entropy'
        # The generator is switched off in this configuration.
        self.generator = False

    def __str__(self):
        """Return all settings as one comma-separated ``name: value`` line."""
        pieces = [
            'dataset: ' + self.dataset,
            'model: ' + self.model,
            'gated: ' + str(self.gated),
            'ranker: ' + self.ranker,
            'generator: ' + str(self.generator),
        ]
        return ', '.join(pieces)

# Build the hard-coded configuration, print it, then run the evaluation with it.
args = custom_args()
print(args)
main(args)

dataset: cifar10, model: mobilenetV2, gated: False, ranker: entropy, generator: False
Files already downloaded and verified
Files already downloaded and verified
Accuracy of the network on the 10000 test images: 92 %


In [31]:
# This is the end-to-end pipeline for the project.
# This cell only evaluates the pipeline on the test set; no training happens here.
# Run this on the server (what we call "offline").

# Configuration for this cell: gating enabled, no generator.
from Dataloaders.dataloader_cifar10 import Dataloader_cifar10
import argparse
import torch
import os
import sys
import numpy as np
import time
from Models import mobilenetv2
from Utils import utils
from Models import gatedmodel
from Models import generator

def main(args):
    """Evaluate the split client/server pipeline on the test set and print its accuracy.

    Per batch the data flows through: client_model -> (optional) channel
    gating with the entropy ranker -> (optional) generator, or zero-filling of
    the channels that were not selected -> server_model. Whenever a gate is
    chosen, its index, the batch confidence and the number of selected
    channels are printed.

    All models and tensors are moved with ``.cuda()``, so a CUDA device is
    required.

    Args:
        args: configuration object. The attributes read here are ``dataset``
            (only 'cifar10' is implemented), ``model`` (only 'mobilenetV2' is
            implemented), ``gated`` (truthy enables channel gating) and
            ``generator`` (truthy runs a generator on a reduced feature map;
            falsy zero-fills the channels that were not selected).

    Raises:
        NotImplementedError: if ``args.dataset`` or ``args.model`` selects an
            option that has no implementation yet.
    """
    # 1. Test dataloader. The other two values returned by the helper are not
    # used during evaluation.
    if args.dataset == 'cifar10':
        _, test, _ = Dataloader_cifar10(train_batch=128, test_batch=100, seed=2024)
    else:
        # Fail here rather than on an unbound ``test`` further down.
        raise NotImplementedError('dataset not supported: ' + str(args.dataset))

    # 2. Client and server halves of the backbone; during offline evaluation
    # both halves live on the server.
    if args.model == 'mobilenetV2':
        client_model, server_model = mobilenetv2.stupid_model_splitter(
            weight_path='./Weights/cifar-10/MobileNetV2.pth')
    else:
        raise NotImplementedError('model not supported: ' + str(args.model))

    # 3. Ranker used to pick the channels that are transferred to the server.
    # input: embs, percentage, output: s_emb, s_ind
    ranker = utils.ranker_entropy

    # 4./5. One gate and one generator per channel budget (25%, 50%, 75%).
    gated_rates = [0.25, 0.5, 0.75]
    channel2ind = {8: 0, 16: 1, 24: 2, 32: 3}
    input_channel = 32
    gateds = [
        gatedmodel.GatedRegression(
            input_size=int(input_channel * rate),
            weight=32,
            height=32,
            output_size=10)
        for rate in gated_rates
    ]
    generators = [
        generator.Generator(
            inputsize=int(input_channel * rate),
            hiddensize=32,
            outputsize=32)
        for rate in gated_rates
    ]

    # Everything runs in inference mode.
    client_model.eval()
    server_model.eval()
    for gate_model, gen_model in zip(gateds, generators):
        gate_model.eval()
        gen_model.eval()

    global_threshold = 0.8  # mean batch confidence a gate must exceed to be chosen
    full_index = channel2ind[input_channel]  # index meaning "all channels present"

    # Indices of the gates / generators whose weights are already loaded.
    loaded_gates = set()
    loaded_generators = set()

    def _ready(models, loaded, index, path_prefix):
        """Return models[index] on the GPU, reading its weights from disk only on first use."""
        model = models[index]
        if index not in loaded:
            model.load_state_dict(torch.load(path_prefix + str(index) + '.pth'))
            model.cuda()
            loaded.add(index)
        return model

    with torch.no_grad():
        correct = 0
        total = 0
        client_model = client_model.cuda()
        server_model = server_model.cuda()
        for images, labels in test:
            images, labels = images.cuda(), labels.cuda()
            # First, run the client model (the part meant for the IoT device).
            out = client_model(images)

            # Second, try the gates from the smallest channel budget upwards
            # and stop at the first one that is confident enough.
            gate_exited = False
            if args.gated:
                ranker_input = out.cpu()  # the ranker is fed a CPU tensor
                for counter, rate in enumerate(gated_rates):
                    s_emb, s_ind = ranker(ranker_input, rate)
                    s_emb = s_emb.cuda()
                    cur_gated = _ready(
                        gateds, loaded_gates, counter,
                        './Weights/cifar-10/GatedRegression_')
                    g_emb = cur_gated(s_emb)  # b, n
                    # Per-sample confidence is the sigmoid of the largest
                    # output; the decision uses the mean over the batch.
                    g_conf = torch.sigmoid(torch.max(g_emb, dim=1).values).mean()
                    if g_conf > global_threshold:
                        gate_exited = True
                        break
                # If no gate was confident enough, ``out`` is left untouched so
                # the full feature map is forwarded, instead of indexing past
                # the end of ``gated_rates``.

            if gate_exited:
                print('The choosen gate is: ', counter, 'The confidence is: ', g_conf.item(), 'The size of channel is: ', s_emb.size(1))
                out = s_emb
                rec_ind = s_ind

            # A sender/receiver pair would sit here; on the server there is none.
            rec_size2ind = channel2ind[out.size(1)]

            if args.generator:
                # The generator is only needed when some channels are missing.
                if rec_size2ind != full_index:
                    cur_gen = _ready(
                        generators, loaded_generators, rec_size2ind,
                        './Weights/cifar-10/generator_')
                    out = cur_gen(out)
            elif gate_exited:
                # No generator: place the selected channels into an all-zero
                # feature map at their original channel positions.
                n_out = torch.zeros(out.size(0), 32, 32, 32, device=out.device)
                n_out[:, rec_ind, :, :] = out
                out = n_out

            # Finally, run the server model and accumulate the accuracy.
            out = server_model(out)
            _, predicted = torch.max(out, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))

# if __name__ == '__main__':
#     print('enter')
#     parser = argparse.ArgumentParser()
#     # we need the name of model, the name of dataset
#     parser.add_argument('--dataset', type=str, default='cifar10', help='name of dataset')
#     # parser.add_argument('--iot_model', type=str, default='mobilenetV2', help='name of the model on the iot')
#     parser.add_argument('--reducer', type=str, default='entrophy', help='name of the reducer')
#     parser.add_argument('--client', type=str, default='LTE', help='name of the network condition on the client side')
#     parser.add_argument('--server', type=str, default='LTE', help='name of the network condition on the server side')
#     parser.add_argument('--generator', type=str, default='None', help='name of the generator')
#     # parser.add_argument('--server_model', type=str, default='mobilenetV2', help='name of the model on the server, should be the same as it on the iot')
#     parser.add_argument('--device', type=str, default='home', help='run on which device, home, tintin, rpi, pico, jetson?')
#     parser.add_argument('--model', type=str, default='mobilenetV2', help='name of the model')
#     args = parser.parse_args()
#     main(args)

class custom_args:
    """Hard-coded replacement for the command-line arguments passed to ``main``."""

    def __init__(self):
        # Dataset and backbone selection.
        self.dataset, self.model = 'cifar10', 'mobilenetV2'
        # Channel gating is switched on in this configuration.
        self.gated = True
        # Name of the channel ranker.
        self.ranker = 'entropy'
        # The generator is switched off in this configuration.
        self.generator = False

    def __str__(self):
        """Return all settings as one comma-separated ``name: value`` line."""
        pieces = [
            'dataset: ' + self.dataset,
            'model: ' + self.model,
            'gated: ' + str(self.gated),
            'ranker: ' + self.ranker,
            'generator: ' + str(self.generator),
        ]
        return ', '.join(pieces)

# Build the hard-coded configuration, print it, then run the evaluation with it.
args = custom_args()
print(args)
main(args)

dataset: cifar10, model: mobilenetV2, gated: True, ranker: entropy, generator: False
Files already downloaded and verified
Files already downloaded and verified
The choosen gate is:  2
The choosen gate is:  2
The choosen gate is:  2
The choosen gate is:  2
The choosen gate is:  2
The choosen gate is:  2
The choosen gate is:  2
The choosen gate is:  2
The choosen gate is:  2
The choosen gate is:  2
The choosen gate is:  2
The choosen gate is:  2
The choosen gate is:  2
The choosen gate is:  2
The choosen gate is:  2
The choosen gate is:  2
The choosen gate is:  2
The choosen gate is:  2
The choosen gate is:  2
The choosen gate is:  2
The choosen gate is:  2
The choosen gate is:  2
The choosen gate is:  2
The choosen gate is:  2
The choosen gate is:  2
The choosen gate is:  2
The choosen gate is:  2
The choosen gate is:  2
The choosen gate is:  2
The choosen gate is:  2
The choosen gate is:  2
The choosen gate is:  2
The choosen gate is:  2
The choosen gate is:  2
The choosen gate is:  2

In [34]:
# This is the end-to-end pipeline for the project.
# This cell only evaluates the pipeline on the test set; no training happens here.
# Run this on the server (what we call "offline").

# Configuration for this cell: gating enabled, generator enabled.
from Dataloaders.dataloader_cifar10 import Dataloader_cifar10
import argparse
import torch
import os
import sys
import numpy as np
import time
from Models import mobilenetv2
from Utils import utils
from Models import gatedmodel
from Models import generator

def main(args):
    """Evaluate the split client/server pipeline on the test set and print its accuracy.

    Per batch the data flows through: client_model -> (optional) channel
    gating with the entropy ranker -> (optional) generator, or zero-filling of
    the channels that were not selected -> server_model. Whenever a gate is
    chosen, its index is printed.

    All models and tensors are moved with ``.cuda()``, so a CUDA device is
    required.

    Args:
        args: configuration object. The attributes read here are ``dataset``
            (only 'cifar10' is implemented), ``model`` (only 'mobilenetV2' is
            implemented), ``gated`` (truthy enables channel gating) and
            ``generator`` (truthy runs a generator on a reduced feature map;
            falsy zero-fills the channels that were not selected).

    Raises:
        NotImplementedError: if ``args.dataset`` or ``args.model`` selects an
            option that has no implementation yet.
    """
    # 1. Test dataloader. The other two values returned by the helper are not
    # used during evaluation.
    if args.dataset == 'cifar10':
        _, test, _ = Dataloader_cifar10(train_batch=128, test_batch=100, seed=2024)
    else:
        # Fail here rather than on an unbound ``test`` further down.
        raise NotImplementedError('dataset not supported: ' + str(args.dataset))

    # 2. Client and server halves of the backbone; during offline evaluation
    # both halves live on the server.
    if args.model == 'mobilenetV2':
        client_model, server_model = mobilenetv2.stupid_model_splitter(
            weight_path='./Weights/cifar-10/MobileNetV2.pth')
    else:
        raise NotImplementedError('model not supported: ' + str(args.model))

    # 3. Ranker used to pick the channels that are transferred to the server.
    # input: embs, percentage, output: s_emb, s_ind
    ranker = utils.ranker_entropy

    # 4./5. One gate and one generator per channel budget (25%, 50%, 75%).
    gated_rates = [0.25, 0.5, 0.75]
    channel2ind = {8: 0, 16: 1, 24: 2, 32: 3}
    input_channel = 32
    gateds = [
        gatedmodel.GatedRegression(
            input_size=int(input_channel * rate),
            weight=32,
            height=32,
            output_size=10)
        for rate in gated_rates
    ]
    generators = [
        generator.Generator(
            inputsize=int(input_channel * rate),
            hiddensize=32,
            outputsize=32)
        for rate in gated_rates
    ]

    # Everything runs in inference mode.
    client_model.eval()
    server_model.eval()
    for gate_model, gen_model in zip(gateds, generators):
        gate_model.eval()
        gen_model.eval()

    global_threshold = 0.8  # mean batch confidence a gate must exceed to be chosen
    full_index = channel2ind[input_channel]  # index meaning "all channels present"

    # Indices of the gates / generators whose weights are already loaded.
    loaded_gates = set()
    loaded_generators = set()

    def _ready(models, loaded, index, path_prefix):
        """Return models[index] on the GPU, reading its weights from disk only on first use."""
        model = models[index]
        if index not in loaded:
            model.load_state_dict(torch.load(path_prefix + str(index) + '.pth'))
            model.cuda()
            loaded.add(index)
        return model

    with torch.no_grad():
        correct = 0
        total = 0
        client_model = client_model.cuda()
        server_model = server_model.cuda()
        for images, labels in test:
            images, labels = images.cuda(), labels.cuda()
            # First, run the client model (the part meant for the IoT device).
            out = client_model(images)

            # Second, try the gates from the smallest channel budget upwards
            # and stop at the first one that is confident enough.
            gate_exited = False
            if args.gated:
                ranker_input = out.cpu()  # the ranker is fed a CPU tensor
                for counter, rate in enumerate(gated_rates):
                    s_emb, s_ind = ranker(ranker_input, rate)
                    s_emb = s_emb.cuda()
                    cur_gated = _ready(
                        gateds, loaded_gates, counter,
                        './Weights/cifar-10/GatedRegression_')
                    g_emb = cur_gated(s_emb)  # b, n
                    # Per-sample confidence is the sigmoid of the largest
                    # output; the decision uses the mean over the batch.
                    g_conf = torch.sigmoid(torch.max(g_emb, dim=1).values).mean()
                    if g_conf > global_threshold:
                        gate_exited = True
                        break
                # If no gate was confident enough, ``out`` is left untouched so
                # the full feature map is forwarded, instead of indexing past
                # the end of ``gated_rates``.

            if gate_exited:
                print('The choosen gate is: ', counter)
                out = s_emb
                rec_ind = s_ind

            # A sender/receiver pair would sit here; on the server there is none.
            rec_size2ind = channel2ind[out.size(1)]

            if args.generator:
                # The generator is only needed when some channels are missing.
                if rec_size2ind != full_index:
                    cur_gen = _ready(
                        generators, loaded_generators, rec_size2ind,
                        './Weights/cifar-10/generator_')
                    out = cur_gen(out)
            elif gate_exited:
                # No generator: place the selected channels into an all-zero
                # feature map at their original channel positions.
                n_out = torch.zeros(out.size(0), 32, 32, 32, device=out.device)
                n_out[:, rec_ind, :, :] = out
                out = n_out

            # Finally, run the server model and accumulate the accuracy.
            out = server_model(out)
            _, predicted = torch.max(out, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))

# if __name__ == '__main__':
#     print('enter')
#     parser = argparse.ArgumentParser()
#     # we need the name of model, the name of dataset
#     parser.add_argument('--dataset', type=str, default='cifar10', help='name of dataset')
#     # parser.add_argument('--iot_model', type=str, default='mobilenetV2', help='name of the model on the iot')
#     parser.add_argument('--reducer', type=str, default='entrophy', help='name of the reducer')
#     parser.add_argument('--client', type=str, default='LTE', help='name of the network condition on the client side')
#     parser.add_argument('--server', type=str, default='LTE', help='name of the network condition on the server side')
#     parser.add_argument('--generator', type=str, default='None', help='name of the generator')
#     # parser.add_argument('--server_model', type=str, default='mobilenetV2', help='name of the model on the server, should be the same as it on the iot')
#     parser.add_argument('--device', type=str, default='home', help='run on which device, home, tintin, rpi, pico, jetson?')
#     parser.add_argument('--model', type=str, default='mobilenetV2', help='name of the model')
#     args = parser.parse_args()
#     main(args)

class custom_args:
    """Hard-coded replacement for the command-line arguments passed to ``main``."""

    def __init__(self):
        # Dataset and backbone selection.
        self.dataset, self.model = 'cifar10', 'mobilenetV2'
        # Channel gating is switched on in this configuration.
        self.gated = True
        # Name of the channel ranker.
        self.ranker = 'entropy'
        # The generator is switched on in this configuration.
        self.generator = True

    def __str__(self):
        """Return all settings as one comma-separated ``name: value`` line."""
        pieces = [
            'dataset: ' + self.dataset,
            'model: ' + self.model,
            'gated: ' + str(self.gated),
            'ranker: ' + self.ranker,
            'generator: ' + str(self.generator),
        ]
        return ', '.join(pieces)

# Build the hard-coded configuration, print it, then run the evaluation with it.
args = custom_args()
print(args)
main(args)

dataset: cifar10, model: mobilenetV2, gated: True, ranker: entropy, generator: True
Files already downloaded and verified
Files already downloaded and verified
The choosen gate is:  2
The choosen gate is:  2
The choosen gate is:  2
The choosen gate is:  2
The choosen gate is:  2
The choosen gate is:  2
The choosen gate is:  2
The choosen gate is:  2
The choosen gate is:  2
The choosen gate is:  2
The choosen gate is:  2
The choosen gate is:  2
The choosen gate is:  2
The choosen gate is:  2
The choosen gate is:  2
The choosen gate is:  2
The choosen gate is:  2
The choosen gate is:  2
The choosen gate is:  2
The choosen gate is:  2
The choosen gate is:  2
The choosen gate is:  2
The choosen gate is:  2
The choosen gate is:  2
The choosen gate is:  2
The choosen gate is:  2
The choosen gate is:  2
The choosen gate is:  2
The choosen gate is:  2
The choosen gate is:  2
The choosen gate is:  2
The choosen gate is:  2
The choosen gate is:  2
The choosen gate is:  2
The choosen gate is:  2
