In [1]:
import numpy as np
import grpc
import proto.cbsp_pb2 as cbsp_pb2
import proto.cbsp_pb2_grpc as cbsp_pb2_grpc
from concurrent import futures
import pickle
import torch

In [2]:
class ClientMessage():
    def __init__(self, CbspMsg):
        self.CbspMsg = CbspMsg
    
    # Note: All the serialization and deserialization use Lazy Dict serialization!
    # Utility Functions
    def torchModel2NumpyParams(self, model): 
        params = [val.cpu().numpy() for _, val in model.state_dict().items()]
        return params

    def numpyParams2TorchModel(self, model, params): 
        model_params = model.state_dict()
        for key, val in model_params.items():
            if len(params) > 0:
                model_params[key] = torch.tensor(params.pop(0))
        model.load_state_dict(model_params)
        return model
                        
    # Orphaned Function
    def serializeDictStrict(self, info): 
        info_ = {}
        for key, value in info.items():
            assert (isinstance(key, str)), "Dictionary Keys contain non String values!"
            assert (isinstance(value, int) or isinstance(value, str) or isinstance(value, float) or isinstance(value, bool)), "Dictionary Contains Unsupported Value types!"
            
            value_ = None
            if(isinstance(value, bool)):    
                value_ = self.CbspMsg.Constant(bool=value)
            elif(isinstance(value, int)):    
                value_ = self.CbspMsg.Constant(sint64=value)
            elif(isinstance(value, float)):    
                value_ = self.CbspMsg.Constant(double=value)
            elif(isinstance(value, str)):    
                value_ = self.CbspMsg.Constant(string=value)
            else:
                print("UNKNOWN ERROR CONVERTING DICT TO GRPC MSG FORMAT :( (chusko)")
                
            info_[key] = value_
        return info_
    
    def serializeDictLazy(self, info):
        return {"dict":self.CbspMsg.Constant(string=str(info))}
        
    def deserializeDictLazy(self, info):
        return eval(info['dict'].string)

    # Msg Serialization Functions
    def serializePytorchParams(self, params): 
        params_bytelist = []
        for i in params:
            params_bytelist.append(cbsp_pb2.ParameterBytes(tensor=pickle.dumps(i), shape=[0]))
        return cbsp_pb2.PytorchParameters(parameters=params_bytelist, dtype=str(params[0].dtype))
                      
    def serializeGetParametersMsg(self, info):
        info_=self.serializeDictLazy(info)
        return self.CbspMsg.ClientMessage(
            get_parameters = self.CbspMsg.ClientMessage.GetParameters(
                type=self.CbspMsg.ClientMessage.GET_PARAMETERS,
                info=info_
            )
        )
        
    def serializeGetConfigMsg(self, info):
        info_=self.serializeDictLazy(info)
        return self.CbspMsg.ClientMessage(
            get_config = self.CbspMsg.ClientMessage.GetConfig(
                type=self.CbspMsg.ClientMessage.GET_CONFIG,
                info=info_
            )
        )
        
    def serializeSendParametersMsg(self, info, model):
        info_=self.serializeDictLazy(info)
        params = self.torchModel2NumpyParams(model)
        params_grpc=self.serializePytorchParams(params) # params_grpc
        return self.CbspMsg.ClientMessage(
            send_parameters = self.CbspMsg.ClientMessage.SendParameters(
                type=self.CbspMsg.ClientMessage.SEND_PARAMETERS,
                info=info_,
                parameters=params_grpc
            )
        )

    def serializeSendResultsMsg(self, info, results):
        info_=self.serializeDictLazy(info)
        results_=self.serializeDictLazy(results)
        return self.CbspMsg.ClientMessage(
            send_results = self.CbspMsg.ClientMessage.SendResults(
                type=self.CbspMsg.ClientMessage.SEND_RESULTS,
                info=info_,
                results=results_
            )
        )
    
    # Msg Deserialization Functions
    def deserializePytorchParams(self, params_grpc): 
        params_bytelist = params_grpc.parameters
        params_dtype = params_grpc.dtype # Redundant
        params = []
        for i in params_bytelist:
            params.append(pickle.loads(i.tensor))
        return params
        
    def deserializeGetParametersMsg(self, request):
        info_ = request.get_parameters.info

        info = self.deserializeDictLazy(info_)

        return info
        
    def deserializeGetConfigMsg(self, request):
        info_ = request.get_config.info

        info = self.deserializeDictLazy(info_)

        return info
        
    def deserializeSendParametersMsg(self, request, model=None):
        info_ = request.send_parameters.info
        params_grpc = request.send_parameters.parameters
        
        info = self.deserializeDictLazy(info_)
        params = self.deserializePytorchParams(params_grpc)

        if(model):
            model = self.numpyParams2TorchModel(model, params)

        return info, params, model # dtype process is sus

    def deserializeSendResultsMsg(self, request):
        info_ = request.send_results.info
        results_ = request.send_results.results

        info = self.deserializeDictLazy(info_)
        results = self.deserializeDictLazy(results_)

        return info, results

In [3]:
class ServerMessage():
    def __init__(self, CbspMsg):
        self.CbspMsg = CbspMsg
        
    # TODO: Reduce redundancy by separating/modularizing utility functions code
    # Note: All the serialization and deserialization use Lazy Dict serialization!
    # Utility Functions
    def torchModel2NumpyParams(self, model): 
        params = [val.cpu().numpy() for _, val in model.state_dict().items()]
        return params

    def numpyParams2TorchModel(self, model, params): 
        model_params = model.state_dict()
        for key, val in model_params.items():
            if len(params) > 0:
                model_params[key] = torch.tensor(params.pop(0))
        model.load_state_dict(model_params)
        return model
        
    # Orphaned Function
    def serializeDictStrict(self, info): 
        info_ = {}
        for key, value in info.items():
            assert (isinstance(key, str)), "Dictionary Keys contain non String values!"
            assert (isinstance(value, int) or isinstance(value, str) or isinstance(value, float) or isinstance(value, bool)), "Dictionary Contains Unsupported Value types!"
            
            value_ = None
            if(isinstance(value, bool)):    
                value_ = self.CbspMsg.Constant(bool=value)
            elif(isinstance(value, int)):    
                value_ = self.CbspMsg.Constant(sint64=value)
            elif(isinstance(value, float)):    
                value_ = self.CbspMsg.Constant(double=value)
            elif(isinstance(value, str)):    
                value_ = self.CbspMsg.Constant(string=value)
            else:
                print("UNKNOWN ERROR CONVERTING DICT TO GRPC MSG FORMAT :( (chusko)")
                
            info_[key] = value_
        return info_
    
    def serializeDictLazy(self, info):
        return {"dict":self.CbspMsg.Constant(string=str(info))}
        
    def deserializeDictLazy(self, info):
        return eval(info['dict'].string)

    # Msg Serialization Functions        
    def serializePytorchParams(self, params): 
        params_bytelist = []
        for i in params:
            params_bytelist.append(cbsp_pb2.ParameterBytes(tensor=pickle.dumps(i), shape=[0]))
        return cbsp_pb2.PytorchParameters(parameters=params_bytelist, dtype=str(params[0].dtype))
   
    def serializeSendParametersMsg(self, info, model): # TODO: Get Better Naming for these since pytorch specific
        info_=self.serializeDictLazy(info)
        params = self.torchModel2NumpyParams(model)
        params_grpc=self.serializePytorchParams(params) # params_grpc
        return self.CbspMsg.ServerMessage(
            get_parameters = self.CbspMsg.ServerMessage.SendParameters( # TODO: Change this get_parameters to send_parameters
                type=self.CbspMsg.ServerMessage.SEND_PARAMETERS,
                info=info_,
                parameters=params_grpc
            )
        )       

    def serializeSendConfigMsg(self, info):
        info_=self.serializeDictLazy(info)
        return self.CbspMsg.ServerMessage(
            send_config = self.CbspMsg.ServerMessage.SendConfig(
                type=self.CbspMsg.ServerMessage.SEND_CONFIG,
                info=info_
            )
        )
        
    def serializeNormalResponseMsg(self, info, response):
        info_=self.serializeDictLazy(info)
        return self.CbspMsg.ServerMessage(
            normal_response = self.CbspMsg.ServerMessage.NormalResponse(
                type=self.CbspMsg.ServerMessage.MESSAGE_TYPE.NORMAL_RESPONSE,
                info=info_,
                response=response
            )
        )
        
    # Msg Deserialization Functions
    def deserializePytorchParams(self, params_grpc): 
        params_bytelist = params_grpc.parameters
        params_dtype = params_grpc.dtype # Redundant
        params = []
        for i in params_bytelist:
            params.append(pickle.loads(i.tensor))
        return params
        
    def deserializeSendParametersMsg(self, request, model=None):
        info_ = request.get_parameters.info # TODO: Change this get_parameters to send_parameters
        param_bytes = request.get_parameters.parameters # TODO: Change this get_parameters to send_parameters
        
        info = self.deserializeDictLazy(info_)
        params = self.deserializePytorchParams(param_bytes)

        if(model):
            model = self.numpyParams2TorchModel(model, params)

        return info, params, model # dtype process is sus

    def deserializeSendConfigMsg(self, request):
        info_ = request.send_config.info
        info = self.deserializeDictLazy(info_)
        return info
        
    def deserializeNormalResponseMsg(self, request):
        info_ = request.normal_response.info
        info = self.deserializeDictLazy(info_)
        response = request.normal_response.response
        return info, response

In [4]:
import torch
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)

Using cache found in C:\Users\Admin/.cache\torch\hub\pytorch_vision_v0.10.0


# Testing Parameter Receiving

In [5]:
import torch
from torchvision import models, transforms
from PIL import Image
import requests

# Define image transformation
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load and preprocess the test image
image_locn = "./wombat.jpg"
image = Image.open(image_locn).convert("RGB")
input_image = transform(image).unsqueeze(0)

# Load labels for ImageNet classes
LABELS_URL = "https://github.com/anishathalye/imagenet-simple-labels/blob/master/imagenet-simple-labels.json"
labels = requests.get(LABELS_URL).json()['payload']['blob']['rawLines']

In [6]:
# Load pre-trained ResNet-18 model
model = models.resnet18(pretrained=True)
model.eval()

# Perform inference
with torch.no_grad():
    output = model(input_image)

In [7]:
# Get the top 5 predictions
_, indices = torch.topk(output, 5)
probs = torch.nn.functional.softmax(output, dim=1)[0] * 100

# Print the top 5 classes and their probabilities
for i in range(5):
    print(f"{labels[indices[0][i]]}: {probs[indices[0][i]].item():.2f}%")

"wombat",: 97.59%
"beaver",: 2.00%
"marmot",: 0.26%
"cottontail rabbit",: 0.04%
"otter",: 0.02%


In [8]:
def recv_param_test(model):
    model.eval()
    with torch.no_grad():
        output = model(input_image)
    # Get the top 5 predictions
    _, indices = torch.topk(output, 5)
    probs = torch.nn.functional.softmax(output, dim=1)[0] * 100
    
    # Print the top 5 classes and their probabilities
    for i in range(5):
        print(f"{labels[indices[0][i]]}: {probs[indices[0][i]].item():.2f}%")

In [9]:
import numpy as np
import grpc
import proto.cbsp_pb2 as cbsp_pb2
import proto.cbsp_pb2_grpc as cbsp_pb2_grpc
from concurrent import futures

class CommunicationService(cbsp_pb2_grpc.CommunicationServiceServicer):
    # As per docs, this class is not supposed to have init function, I added this just so I can add the cbsp_pb2
    # as class' internal variable and also ClientManager as internal object, consider removing if unexplained errors persist 
    def __init__(self, cbsp_pb2):
        self.CbspMsg = cbsp_pb2
        self.cm = ClientMessage(self.CbspMsg)
        self.sm = ServerMessage(self.CbspMsg)
        
    def BidirectionalStream(self, request_iterator, context):        
        # print(hasattr(request_iterator,'send_results')) # Warning: Doesn't work - A grpc msg has attributes for all msg types, just that rest of them are empty
        # Check the type of message received and process accordingly
        if(request_iterator.WhichOneof("client_message") == "get_parameters"):
            print("Received GetParameters message")
            info = self.cm.deserializeGetParametersMsg(request_iterator)
            print(info)
        elif(request_iterator.WhichOneof("client_message") == "get_config"):
            print("Received GetConfig message")
            info = self.cm.deserializeGetConfigMsg(request_iterator)
            print(info)
        elif(request_iterator.WhichOneof("client_message") == "send_parameters"):
            print("Received SendParameters message")
            model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)
            info, params, model = self.cm.deserializeSendParametersMsg(request_iterator, model)
            print(info,len(params),model)
            recv_param_test(model)
            # info, param_float, param_dtype, model = cm.deserializeSendParametersMsg(request_iterator, model) # Alternate way to get updated model as well
        elif(request_iterator.WhichOneof("client_message") == "send_results"):
            print("Received SendResults message")
            info, results = self.cm.deserializeSendResultsMsg(request_iterator)
            print(info, results)
        else:
            print("ERROR: Received unknown message type")
        model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=False)
        server_response = self.sm.serializeSendParametersMsg({'aa':'moddale'}, model)
        return server_response

def run_server():
    options = [
        ('grpc.max_receive_message_length', 1024 * 1024 * 1000)  # Adjust the size as needed
    ]
    server = grpc.server(futures.ThreadPoolExecutor(), options=options)
    cbsp_pb2_grpc.add_CommunicationServiceServicer_to_server(CommunicationService(cbsp_pb2), server)
    
    server.add_insecure_port('[::]:50052')
    server.start()
    print('Server Started...')
    server.wait_for_termination()

if __name__ == '__main__':
    run_server()

Server Started :D


Received SendParameters message


Using cache found in C:\Users\Admin/.cache\torch\hub\pytorch_vision_v0.10.0


{'seed': '69'} 0 ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU

Using cache found in C:\Users\Admin/.cache\torch\hub\pytorch_vision_v0.10.0
