In [1]:
import torch
from torchvision import transforms as tfs
from torch.utils.data import DataLoader
from utils import hyperparameters
from dataset import *
from generator import Generator
from communications import TCPClient
import socket

In [2]:
# Sequence length the generator is configured with (also forwarded to `hyperparameters`).
seq_len = 50

# Network hyperparameters. These must match the configuration the checkpoint below
# was trained with, otherwise `load_state_dict` will fail on missing/mismatched keys.
netparams = hyperparameters(w=320,
                            h=239,
                            latent_dim=512,
                            history_length=8,
                            future_length=12,
                            cnn_filters=["32", "64", "128", "256", "512"],
                            lin_neurons=["1024", "1024"],
                            enc_layers=2,
                            lstm_dim=512,
                            output_dim=2,
                            up_criterion=0.9,
                            down_criterion=0.0,
                            seq_len=seq_len,
                            attention='add')

# Preprocessing applied to every incoming frame before it reaches the generator.
# NOTE(review): torchvision's `Resize` takes (height, width), so this resizes to
# h=320, w=239, whereas `netparams` declares w=320, h=239. Left unchanged because the
# checkpoint was presumably trained with this exact transform — TODO confirm against
# the training pipeline before "fixing" the order.
# The 4-value mean/std means the transform expects 4-channel images (e.g. RGBA) —
# TODO confirm the simulator sends 4 channels.
image_transforms = tfs.Compose([tfs.Resize((320, 239)),
                               tfs.ToTensor(),
                               tfs.Normalize([0.5, 0.5, 0.5, 0.5], [0.5, 0.5, 0.5, 0.5])])

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
gen = Generator(netparams, device).to(device)
# map_location remaps stored tensors onto the selected device; without it a checkpoint
# saved from a GPU cannot be loaded on the CPU fallback chosen above.
checkpoint = torch.load('gen_Glre3Dlre4_64btch_120ep.pth', map_location=device)
gen.load_state_dict(checkpoint)
# Switch to inference behaviour (dropout / batch-norm layers, if any, stop training-mode
# updates). The return value is assigned only to avoid echoing the full model repr.
a = gen.eval()

In [3]:
# Wire-protocol settings handed to TCPClient. Presumably: size of the message-length
# header, text encoding of messages, and recv() chunk size in bytes — TODO confirm in
# communications.TCPClient.
HEADERSIZE = 16
TEXTCODING = 'utf-8'
BUFFERREAD = 1024
PORT = 8080

# Listening IPv4 TCP socket the simulator connects to.
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
# Allow the port to be re-bound right away when the notebook is re-run (e.g. after a
# kernel restart); otherwise bind() can fail with "Address already in use" while the
# previous socket lingers in TIME_WAIT.
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
s.bind((socket.gethostname(), PORT))  # host interface and port to listen on
s.listen(5)  # backlog: max number of queued, not-yet-accepted connections

In [4]:
# Accept one simulator connection and answer each received message with a predicted
# route until the client disconnects.
clientsocket, address = s.accept()  # blocks until a client connects
print("Conection from", address, "has been established!")
client = TCPClient(clientsocket, HEADERSIZE, TEXTCODING, BUFFERREAD)
client.Handshake()

# Placeholder replies of shape (1, 12, 3).
# NOTE(review): 12 appears to correspond to `future_length=12` in the setup cell —
# TODO confirm it matches the generator's output length.
waitRoute = torch.zeros((1, 12, 3))
randomRoute = torch.rand((1, 12, 3))  # debugging alternative to waitRoute
predicted_routes = []  # every predicted route sent to the client, as numpy arrays
noise = None  # latent noise, (re)sampled at the start of each simulation

try:
    while client.connected:
        imgs, routes, objective, new_simulation = client.Receive()
        # Not enough context to predict yet: reply with the placeholder route.
        if len(imgs) == 0 or len(routes) == 0 or len(objective) == 0:
            client.Send(waitRoute)
            continue

        # Keep the same noise vector throughout each simulation. Also sample one if the
        # first message received is not flagged as a new simulation; previously `noise`
        # was unbound in that case and the loop crashed with NameError.
        if new_simulation or noise is None:
            noise = torch.randn((netparams['history'], netparams['latent_dim']))
            if new_simulation:
                print("New simulation starting")

        test_set = OnlineProcessing(imgs,
                                    routes,
                                    objective,
                                    netparams['history'],
                                    netparams['predict_seq'],
                                    noise,
                                    image_transforms)
        test_loader = DataLoader(test_set, batch_size=1, shuffle=False)

        # Reset for every message so a route from an earlier message is never re-sent.
        predicted_route = None
        # Pure inference: skip building an autograd graph.
        with torch.no_grad():
            # Only the prediction for the last sample is kept; presumably OnlineProcessing
            # yields a single sample per message — TODO confirm.
            for imgs_t, z_t, past_routes_t, objective_t in test_loader:
                imgs_t = imgs_t.to(device)
                z_t = z_t.to(device)
                past_routes_t = get_relative(past_routes_t)
                past_routes_t = past_routes_t.to(device)
                objective_t = objective_t.to(device)
                predicted_route = gen(imgs_t, z_t, past_routes_t, objective_t)

        if predicted_route is None:
            # The dataset produced no samples: fall back to the placeholder reply.
            client.Send(waitRoute)
            continue

        client.Send(predicted_route)
        predicted_routes.append(predicted_route.cpu().detach().numpy())
finally:
    # Release the connection even if inference or I/O raises (or the cell is
    # interrupted); closing an already-closed socket is a no-op.
    clientsocket.close()


Conection from ('127.0.0.1', 45162) has been established!
Handshake successful
New simulation starting
Connection shut down from the client


In [5]:
# Show every route the serving loop sent back, one numpy array per reply.
# (Loop variable is not named `pd` to avoid shadowing the usual pandas alias.)
for route in predicted_routes:
    print(route)

[[[ 0.99999994 -0.9994938  -0.99998885]
  [ 0.99999994 -0.99875784 -0.9999184 ]
  [ 0.9999999  -0.99941415 -0.9999517 ]
  [ 0.9999998  -0.9994514  -0.999932  ]
  [ 0.9999999  -0.9993624  -0.99993366]
  [ 0.9999999  -0.99927056 -0.9999321 ]
  [ 0.9999999  -0.9991398  -0.99993134]
  [ 0.9999999  -0.9989621  -0.9999315 ]
  [ 0.9999999  -0.9987208  -0.99993205]
  [ 0.9999999  -0.9983983  -0.9999328 ]
  [ 0.9999999  -0.9979863  -0.99993366]
  [ 0.9999999  -0.99749845 -0.9999346 ]]]
[[[ 0.9999995  -0.9395151  -0.99999535]
  [ 0.99999994 -0.6973007  -0.9983296 ]
  [ 0.9999998  -0.74881774 -0.99839526]
  [ 0.9999997  -0.39546034 -0.9973215 ]
  [ 0.99999946  0.64080733 -0.9970801 ]
  [ 0.9999991   0.9293947  -0.9971864 ]
  [ 0.99999857  0.9693968  -0.99742275]
  [ 0.9999978   0.9923549  -0.9983064 ]
  [ 0.9999969   0.9993303  -0.99936986]
  [ 0.99999577  0.9998259  -0.99952984]
  [ 0.99999434  0.999877   -0.9993809 ]
  [ 0.99999267  0.9998864  -0.9991419 ]]]
[[[ 0.99999976 -0.25423387 -0.999992

In [13]:
import matplotlib.pyplot as plt
from PIL import Image

# Visualize the frames from the last message received by the serving loop,
# one titled figure per frame.
# NOTE(review): assumes each element of `imgs` is something `imshow` accepts
# (array or PIL image) — TODO confirm what TCPClient.Receive returns.
for frame_idx, img in enumerate(imgs):
    fig, ax = plt.subplots()
    ax.imshow(img)
    ax.set_title(f"Received frame {frame_idx}")
    ax.axis("off")
plt.show()

In [12]:
# Debug check: raw bytes left in the TCPClient's message buffer after the session
# (the recorded output shows b'' once the client had disconnected).
leftover_bytes = client.full_msg
leftover_bytes

b''