In [4]:
# First set up and run the Kafka server, and install kafka-python. The commands can be found in 'kakfa commands.txt'

from kafka import KafkaAdminClient
from kafka.admin import NewTopic

# Create an instance of KafkaAdminClient
admin_client = KafkaAdminClient(
    bootstrap_servers="localhost:9092",  # Update with your Kafka broker(s)
    client_id='my_client'
)

topic_name = '3DOF_old'

# Two partitions: the server loop in this notebook reads requests from
# partition 0 and publishes its predictions to partition 1.
num_partitions = 2
replication_factor = 1
topic_configs = {
    "retention.ms": str(600000),  # messages older than 10 minutes are deleted
    "retention.bytes": str(int(100e6)),  # once the stored messages exceed 100 MB the oldest ones are deleted
    'max.message.bytes': str(int(5e6)),  # maximum allowed size of a single message: 5 MB
}

my_first_topic = NewTopic(name=topic_name, num_partitions=num_partitions,
                          replication_factor=replication_factor,
                          topic_configs=topic_configs)

# create_topics raises when the topic already exists, which made this cell fail
# on every re-run. Only create the topic when the broker does not list it yet.
if topic_name not in admin_client.list_topics():
    admin_client.create_topics(new_topics=[my_first_topic])

# List all topics
topics = admin_client.list_topics()
print("Topics:", topics)




Topics: ['3DOF', '3DOF_old', 'my_first_topic']


In [5]:
from kafka import KafkaConsumer, KafkaProducer,TopicPartition
import msgpack
import simplejpeg
import torch


def dict_to_bytes(data_dict):
    """Serialize `data_dict` with MessagePack and return the packed payload.

    The object is handed to `msgpack.packb` with `use_bin_type=True`.
    """
    packed_payload = msgpack.packb(data_dict, use_bin_type=True)
    return packed_payload


def bytes_to_dict(json_str):
    """Deserialize a MessagePack payload back into a Python object.

    Despite its name, `json_str` is not JSON text: it is passed unchanged to
    `msgpack.unpackb` (with `raw=False`), so it must be a MessagePack payload
    such as the one produced by `dict_to_bytes`.
    """
    unpacked_obj = msgpack.unpackb(json_str, raw=False)
    return unpacked_obj


def encode_pytorch_image(img):
    """
    Encode a PyTorch image tensor as JPEG bytes.

    Args:
        img: float tensor of shape (3, H, W) whose values are expected to lie
            in the range [0, 1]. The tensor may live on any device; it is
            detached and copied to the CPU before encoding.

    Returns:
        The JPEG-compressed image, as returned by simplejpeg.encode_jpeg.
    """
    # Tensor.numpy() only works for CPU tensors, so bring the data to the host first.
    img_cpu = img.detach().cpu()
    # Clamp before scaling so slightly out-of-range values saturate instead of
    # overflowing the uint8 cast. Round instead of truncating: in float32 a
    # value such as k/255 can scale back to just below k, and truncation would
    # then drop it to k-1 (a decode -> encode round trip would darken pixels).
    img_scaled = torch.round(torch.clamp(img_cpu, 0.0, 1.0) * 255.0)
    # Reorder channels-first (3, H, W) to channels-last (H, W, 3) and make the
    # memory contiguous before handing the array to the JPEG encoder.
    img_numpy = torch.permute(img_scaled, (1, 2, 0)).contiguous()
    img_numpy = img_numpy.to(torch.uint8).numpy()
    grd_img_bytes = simplejpeg.encode_jpeg(img_numpy)

    return grd_img_bytes

def decode_pytorch_image(img_bytes):
    """
    Decode a JPEG byte string into a PyTorch image tensor.

    The bytes are decoded with simplejpeg.decode_jpeg, converted to float32,
    divided by 255 and rearranged from channels-last to channels-first.

    Returns:
        A contiguous float32 tensor of shape (3, H, W).
    """
    decoded_hwc = simplejpeg.decode_jpeg(img_bytes)
    as_float = torch.tensor(decoded_hwc, dtype=torch.float32)
    channels_first = as_float.div(255.0).permute(2, 0, 1)
    return channels_first.contiguous()


# No topic is given to the constructor because the consumer is attached to one
# specific partition through assign() below. consumer_timeout_ms is not set
# either, so iterating over the consumer keeps waiting for new messages.
consumer = KafkaConsumer(
    bootstrap_servers=['localhost:9092'],
    auto_offset_reset='latest',  # begin with the newest messages of the partition
)

# Incoming requests are read from partition 0 of the topic only.
tp1 = TopicPartition(topic_name, 0)
consumer.assign([tp1])

# Producer used to publish the predictions back to the broker.
producer = KafkaProducer(bootstrap_servers=['localhost:9092'])

In [6]:
import os

import utils
from Ford_dataset_s import SatGrdDatasetFordPresentation, train_logs, train_logs_img_inds, test_logs, test_logs_img_inds
from models_ford_s import ModelFord
from utils_s import render_point_cloud
# from models_ford import ModelFord as ModelFord_orig
from torchvision import transforms
import torch
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt 
import matplotlib.gridspec as gridspec
import numpy as np 
# from VGG import Unet, space2channel

from train_ford_3DOF_s import parse_args

args = parse_args()
# save_path ='ModelsFord/3DoF/Log_3lat20.0m_lon20.0m_rot10.0_Nit1_CrossAttn_FL_SL_3D_Uncertainty'
save_path = 'Log_3lat20.0m_lon20.0m_rot10.0_Nit1_CrossAttn_FL_SL_3D_Uncertainty'
args.train_log_start = 3

# Optional overrides of the parsed arguments (kept for reference):
# cameras = ['FL', 'SL']
# args.image_H = 256 #256
# args.image_W = 1024#1024
# args.cameras = cameras
# args.batch_size = 1
# args.lifting = '3D' #homography

# Only the single log selected by args.train_log_start is loaded.
log_slice = slice(args.train_log_start, args.train_log_start + 1)
# NOTE(review): the test logs are opened with mode='train' — confirm this is intended.
test_set = SatGrdDatasetFordPresentation(logs=test_logs[log_slice],
                                logs_img_inds=test_logs_img_inds[log_slice],
                                shift_range_lat=args.shift_range_lat, shift_range_lon=args.shift_range_lon,
                                rotation_range=args.rotation_range, whole=args.test_whole,
                                H = args.image_H, W = args.image_W, cameras=args.cameras, mode='train')
# testloader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False, pin_memory=True,
#                             num_workers=2, drop_last=False)

# Use the GPU when one is available and fall back to the CPU otherwise, instead
# of failing on the first .to('cuda') on a machine without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# The camera calibration tensors of the dataset are handed to the model through args.
args.Rs = {key: item.to(device) for key, item in test_set.Rs.items()}
args.Ts = {key: item.to(device) for key, item in test_set.Ts.items()}
args.Ks = {key: item.to(device) for key, item in test_set.Ks.items()}
net = ModelFord(args).to(device)

# map_location lets a checkpoint that was saved on a GPU be restored on
# whichever device is in use here.
state_dict = torch.load(os.path.join(save_path, 'model_1.pth'), map_location=device)
# strict=False tolerates key mismatches; report them so that loading a wrong or
# partial checkpoint does not go unnoticed.
load_result = net.load_state_dict(state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print('missing keys:', load_result.missing_keys)
    print('unexpected keys:', load_result.unexpected_keys)

# This notebook only runs inference, so switch layers such as dropout or batch
# norm to evaluation behaviour.
# NOTE(review): confirm ModelFord has no train-mode behaviour this demo relied on.
net.eval()

len(test_set), args.batch_size

Error importing huggingface_hub.hf_api: cannot import name 'COMMON_SAFE_ASCII_CHARACTERS' from 'charset_normalizer.constant' (/home/dimitris/miniconda3/envs/condapy310/lib/python3.10/site-packages/charset_normalizer/constant.py)
/mnt/c/Users/dimitris/Desktop/MyFiles/diplomatikh/pytorch-tensorflow/pytorch/Cross-View-Localization/server client with kafka


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


(3511, 1)

In [7]:
torch.set_float32_matmul_precision('high')
# torch.compile returns a wrapper that keeps the original module as `_orig_mod`.
# Skip the wrap when this cell is run again, so the model is not wrapped twice.
if not hasattr(net, '_orig_mod'):
    net = torch.compile(net)
# NOTE(review): the serving loop calls net.CrossAttn_rot_corr(...) rather than
# net(...); a method reached through the compiled wrapper appears to run the
# original, un-compiled code — confirm that compilation has the intended effect.

In [8]:

def _to_device_batch(item):
    """Move `item` to `device` and give tensors a leading batch dimension.

    A tensor becomes `item.to(device).unsqueeze(0)`. A tuple is turned into a
    list in which every tensor element is converted the same way and every
    other element is kept. Anything else is returned unchanged.
    """
    if isinstance(item, torch.Tensor):
        return item.to(device).unsqueeze(0)
    if isinstance(item, tuple):
        return [elem.to(device).unsqueeze(0) if isinstance(elem, torch.Tensor) else elem
                for elem in item]
    return item


def _decode_to_batch(img_bytes):
    """Decode one received JPEG into a batched tensor placed on `device`."""
    return decode_pytorch_image(img_bytes).unsqueeze(0).to(device)


# The consumer has no timeout configured, so this loop waits for the next
# request indefinitely; interrupt the kernel to stop serving.
for message in consumer:
    print('data recived')
    data_recived = bytes_to_dict(message.value)
    # The ground images come from the request, not from the local dataset.
    grd_imgs = [_decode_to_batch(data_recived['image_0']),
                _decode_to_batch(data_recived['image_1'])]
    data_id = data_recived['data_id']

    # Everything else (satellite image, ground-truth shifts, ...) is looked up
    # locally through the sample index sent by the client.
    data_local = test_set[data_id]
    (sat_img, _, gt_shift_u, gt_shift_v, gt_heading, grd_names,
     sat_img_norot_notran, s_lat, s_lon, g_lat, g_lon, yaw) = [
        _to_device_batch(item) for item in data_local]

    with torch.no_grad():
        pred_u, pred_v, pred_orien = net.CrossAttn_rot_corr(sat_img, grd_imgs, gt_shift_u, gt_shift_v, gt_heading, mode='test')

        # .item() copies the scalar to the host, so no explicit .cpu() is needed.
        data = {'pred_u': pred_u.item(), 'pred_v': pred_v.item(), 'pred_orien': pred_orien.item()}
        data_bytes_send = dict_to_bytes(data)
        # Send the prediction back on partition 1 of the same topic.
        producer.send(topic_name, value=data_bytes_send, partition=1)
        print('data send')


data recived
data send
data recived
data send
data recived
data send
data recived
data send
data recived
data send
data recived
data send
data recived
data send
data recived
data send
data recived
data send
data recived
data send
data recived
data send
data recived
data send
data recived
data send
data recived
data send
data recived
data send
data recived
data send
data recived
data send
data recived
data send
data recived
data send
data recived
data send
data recived
data send
data recived
data send
data recived
data send
data recived
data send
data recived
data send
data recived
data send
data recived
data send
data recived
data send
data recived
data send
data recived
data send
data recived
data send
data recived
data send
data recived
data send


In [3]:
# Clean-up: remove the topic once the serving loop has been stopped.
# delete_topics raises when the topic does not exist, so only delete it when
# the broker still lists it; this keeps the cell safe to run more than once.
if topic_name in admin_client.list_topics():
    delete_response = admin_client.delete_topics(topics=[topic_name])
else:
    delete_response = None
delete_response

DeleteTopicsResponse_v3(throttle_time_ms=0, topic_error_codes=[(topic='3DOF_old', error_code=0)])