In [1]:
import torch
from torchvision import transforms as tfs
from torch.utils.data import DataLoader
from utils import hyperparameters
from dataset import *
from generator import Generator
from communications import TCPClient
import socket

In [2]:
basic_path = '/media/felpipe/Archivos HDD/SocLab/'
train_path = basic_path + 'train'
valid_path = basic_path + 'valid'
test__path = basic_path + 'test'

seq_len = max([dataset_explore(train_path), dataset_explore(valid_path), dataset_explore(test__path)])

netparams = hyperparameters(w=320, 
                            h=239, 
                            latent_dim=512, 
                            history_length=8, 
                            future_length=12,
                            enc_layers=2,
                            lstm_dim=512,
                            output_dim=2,
                            up_criterion=0.9,
                            down_criterion=0.0,
                            seq_len=seq_len,
                            attention='add')

image_transforms = tfs.Compose([tfs.Resize((320, 239)),
                               tfs.ToTensor(),
                               tfs.Normalize([0.5, 0.5, 0.5, 0.5], [0.5, 0.5, 0.5, 0.5])])

# test__set = RobotDataset(test__path, 512, seq_len, image_transforms)

# test__loader = DataLoader(test__set, batch_size=8, shuffle=False)

In [3]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
gen = Generator(netparams, device).to(device)
checkpoint = torch.load('gen_Glre3Dlre4_64btch_120ep.pth')
gen.load_state_dict(checkpoint)
a = gen.eval()
# for batch in test__loader:
#     break

In [4]:
HEADERSIZE = 16
TEXTCODING = 'utf-8'
BUFFERREAD = 1024
PORT = 8021

s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) #Socket Family and socket type
s.bind((socket.gethostname(), PORT)) #IP direction and port
s.listen(5)             

In [5]:
clientsocket, address = s.accept()
print("Conection from", address, "has been established!")
client = TCPClient(clientsocket, HEADERSIZE, TEXTCODING, BUFFERREAD)
client.Handshake()
waitRoute = torch.zeros((1, 12, 3))

while client.connected:
    
    imgs, routes, objective = client.Receive()
    if len(imgs) == 0 or len(routes) == 0 or len(objective) == 0:
        client.Send(waitRoute)
        continue
    
    test_set = OnlineProcessing(imgs, 
                                routes, 
                                objective, 
                                netparams['history'], 
                                netparams['predict_seq'], 
                                netparams['latent_dim'], 
                                image_transforms)
    
    test_loader = DataLoader(test_set, batch_size=1, shuffle=False)
    
    for imgs_t, z_t, past_routes_t, objective_t in test_loader:
        imgs_t = imgs_t.to(device)
        z_t = z_t.to(device)
        past_routes_t = get_relative(past_routes_t)
        past_routes_t = past_routes_t.to(device)
        objective_t = objective_t.to(device)
        predicted_route = gen(imgs_t, z_t, past_routes_t, objective_t)
        
    client.Send(predicted_route)


Conection from ('127.0.0.1', 56170) has been established!
Handshake successful


UnicodeDecodeError: 'utf-8' codec can't decode byte 0x89 in position 0: invalid start byte

In [None]:
import matplotlib.pyplot as plt
import numpy as np
for img in imgs:
    plt.figure()
    plt.imshow(np.asarray(img))

In [6]:
import copy
import io

In [7]:
full_msg = copy.copy(client.last_msg)

In [8]:
imgs = []
while True:
    byteimg = b''
    header = full_msg[:HEADERSIZE]
    if header == b'':
        break
    imglen = int(header.decode(TEXTCODING))
    byteimg = full_msg[HEADERSIZE:(HEADERSIZE + imglen)]
    full_msg = full_msg[(HEADERSIZE + imglen):]
    if byteimg.find(b'<SON>') == -1 and byteimg.find(b'<EOC>') == -1:
        img = Image.open(io.BytesIO(byteimg))                
        imgs.append(img)        
    else:
        break

In [14]:
header = full_msg[:HEADERSIZE]
trajlen = int(header.decode(TEXTCODING))
traj = full_msg[HEADERSIZE:(HEADERSIZE + trajlen)].decode(TEXTCODING)
full_msg = full_msg[(HEADERSIZE + trajlen):]


UnicodeDecodeError: 'utf-8' codec can't decode byte 0x89 in position 0: invalid start byte

In [17]:
full_msg = full_msg[(HEADERSIZE + 135483):]
full_msg

b'              23Numbers initiating<SON>             196-1.5244;-0.6314;282.9962|-1.5705;-0.5200;293.5513|-1.6265;-0.3850;303.7834|-1.6892;-0.2336;314.8340|-1.7472;-0.0937;321.7007|-1.8180;0.0774;327.5520|-1.8854;0.2400;339.8716|-1.8231;0.3729;350.1453             14326.2159;-282.9962|25.5892;-293.5513|24.8683;-303.7834|24.1109;-314.8340|23.4586;-321.7007|22.7235;-327.5520|22.0884;-339.8716|20.8219;-350.1453              24Stop stream message<EOS>'

In [16]:
byteimg = b''
header = full_msg[:HEADERSIZE]
imglen = int(header.decode(TEXTCODING))
byteimg = full_msg[HEADERSIZE:(HEADERSIZE + imglen)]
full_msg = full_msg[(HEADERSIZE + imglen):]
if byteimg.find(b'<SON>') == -1 and byteimg.find(b'<EOC>') == -1:
    img = Image.open(io.BytesIO(byteimg))                
    imgs.append(img)   