In [4]:
# load the weight in pretrained mobilnetv2
import torch
weight = torch.load('./Weights/imagenet/pretrained/pretrained_mobilenetv2.pth')

In [7]:
# get the resnet 18 and resnet 50, 152
from torchvision import models
resnet18 = models.resnet18(pretrained=models.ResNet18_Weights.IMAGENET1K_V1)
resnet50 = models.resnet50(pretrained=models.ResNet50_Weights.IMAGENET1K_V2)
resnet152 = models.resnet152(pretrained=models.ResNet152_Weights.IMAGENET1K_V2)
mobilenetv2 = models.mobilenet_v2(pretrained=models.MobileNet_V2_Weights.IMAGENET1K_V2)
mobilenetv3 = models.mobilenet_v3_large(pretrained=models.MobileNet_V3_Large_Weights.IMAGENET1K_V2)

# store these weights
torch.save(resnet18.state_dict(), './Weights/imagenet/pretrained/resnet18.pth')
torch.save(resnet50.state_dict(), './Weights/imagenet/pretrained/resnet50.pth')
torch.save(resnet152.state_dict(), './Weights/imagenet/pretrained/resnet152.pth')
torch.save(mobilenetv2.state_dict(), './Weights/imagenet/pretrained/mobilenetv2.pth')
torch.save(mobilenetv3.state_dict(), './Weights/imagenet/pretrained/mobilenetv3.pth')

# 

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /home/tonypeng/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:03<00:00, 34.1MB/s]
Downloading: "https://download.pytorch.org/models/resnet152-394f9c45.pth" to /home/tonypeng/.cache/torch/hub/checkpoints/resnet152-394f9c45.pth
100%|██████████| 230M/230M [00:07<00:00, 34.5MB/s] 
Downloading: "https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth" to /home/tonypeng/.cache/torch/hub/checkpoints/mobilenet_v3_large-8738ca79.pth
100%|██████████| 21.1M/21.1M [00:00<00:00, 31.4MB/s]


In [32]:
from Models.mobilenetv2_original import MobileNetV2_client, MobileNetV2_server
import os 

def stupid_model_splitter(num_classes = 1000, weight_root = '', device = 'cuda:0', partition = -1):
    # here are have a very stupid splitter for 
    # the restnet101 mode
    # assert that they have at least the same length of models

    c_model = MobileNetV2_client()
    s_model = MobileNetV2_server(num_classes = num_classes)

    c_weight = c_model.state_dict()
    s_weight = s_model.state_dict()
    c_weight_key = list(c_weight.keys())
    s_weight_key = list(s_weight.keys())

    if partition == -1:
        partition = len(c_weight_key)

    cw_path = weight_root + '/client/mobilenetv2.pth'
    sw_path = weight_root + '/server/mobilenetv2.pth'
    if not os.path.exists(cw_path):
        pw_path = weight_root + '/pretrained/mobilenetv2.pth'
        in_weight = torch.load(pw_path, map_location=device)
        assert (len(c_weight_key) + len(s_weight_key) == len(in_weight))

        # reivese the key of weights
        in_weight_keys = list(in_weight.keys())
        for i in range(len(in_weight_keys)):
            if i < partition:
                c_weight[c_weight_key[i]] = in_weight[in_weight_keys[i]]
            else:
                s_weight[s_weight_key[i-partition]] = in_weight[in_weight_keys[i]]
        # store the weights
        if not os.path.exists(weight_root + '/client'):
            os.makedirs(weight_root + '/client')
        if not os.path.exists(weight_root + '/server'):
            os.makedirs(weight_root + '/server')

        torch.save(c_weight, './Weights/imagenet/client/mobilenetv2.pth')
        torch.save(s_weight, './Weights/imagenet/server/mobilenetv2.pth')
    
    c_model.load_state_dict(torch.load(cw_path, map_location=device))
    s_model.load_state_dict(torch.load(sw_path, map_location=device))
    return c_model, s_model

In [44]:
# load the weight in pretrained imagenet
import torch
mobilenetv2 = torch.load('./Weights/imagenet/pretrained/mobilenetv2.pth')
print(mobilenetv2.keys())

# get the mobilenetv2 client and server
client, server = stupid_model_splitter(weight_root = './Weights/imagenet', partition = 3)
print(client)
print(server)


odict_keys(['features.0.0.weight', 'features.0.1.weight', 'features.0.1.bias', 'features.0.1.running_mean', 'features.0.1.running_var', 'features.0.1.num_batches_tracked', 'features.1.conv.0.0.weight', 'features.1.conv.0.1.weight', 'features.1.conv.0.1.bias', 'features.1.conv.0.1.running_mean', 'features.1.conv.0.1.running_var', 'features.1.conv.0.1.num_batches_tracked', 'features.1.conv.1.weight', 'features.1.conv.2.weight', 'features.1.conv.2.bias', 'features.1.conv.2.running_mean', 'features.1.conv.2.running_var', 'features.1.conv.2.num_batches_tracked', 'features.2.conv.0.0.weight', 'features.2.conv.0.1.weight', 'features.2.conv.0.1.bias', 'features.2.conv.0.1.running_mean', 'features.2.conv.0.1.running_var', 'features.2.conv.0.1.num_batches_tracked', 'features.2.conv.1.0.weight', 'features.2.conv.1.1.weight', 'features.2.conv.1.1.bias', 'features.2.conv.1.1.running_mean', 'features.2.conv.1.1.running_var', 'features.2.conv.1.1.num_batches_tracked', 'features.2.conv.2.weight', 'fea

In [3]:
from Models import mobilenetv2, mobilenetv3, resnet
mv2 = mobilenetv2.mobilenetv2_splitter()
mv3 = mobilenetv3.mobilenetv3_splitter()
rn = resnet.resnet_splitter(layers=152)

6
926
932


In [43]:
import torchvision
import torch
import torch.nn as nn
import torch.nn.functional as F 
total_model = torchvision.models.mobilenet_v2()
class t_model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.part1 = torch.nn.Sequential(nn.Linear(10,20))
        self.part2 = torch.nn.Sequential(nn.Linear(20,40))
        self.part3 = torch.nn.Sequential(nn.Linear(40,50))
    def forward(self, x):
        out = self.part1(x)
        out = F.relu(out)
        out = self.part2(out)
        out = F.avg_pool2d(out, 2)
        out = self.part3(out)
        return out
client = t_model()
# client is the first 3 layer of total model
client = torch.nn.Sequential(*list(t_model().children())[:3])
print(client)

Sequential(
  (0): Sequential(
    (0): Linear(in_features=10, out_features=20, bias=True)
  )
  (1): Sequential(
    (0): Linear(in_features=20, out_features=40, bias=True)
  )
  (2): Sequential(
    (0): Linear(in_features=40, out_features=50, bias=True)
  )
)
