In [None]:
import cv2, json, ast
import os

import numpy as np
import matplotlib.pyplot as plt
import paho.mqtt.client as mqtt

from torch.utils.data import DataLoader
from pytorch_lightning import Trainer

from anomalib.models.padim.lightning_model import PadimLightning
from anomalib.data.inference import InferenceDataset
from anomalib.pre_processing.transforms import Denormalize

In [None]:
def load_model(root):
    """Load a PaDiM checkpoint from ``root`` and build a Lightning trainer.

    The directory listing is sorted so the same file is chosen on every run
    (``os.listdir`` returns entries in arbitrary order); the first entry in
    that sorted order is loaded.

    Args:
        root: Directory containing the checkpoint file(s).

    Returns:
        A ``(model, trainer)`` tuple: the model restored via
        ``PadimLightning.load_from_checkpoint`` and a ``Trainer`` created with
        the progress bar disabled.

    Raises:
        FileNotFoundError: If ``root`` does not exist or contains no entries.
    """
    candidates = sorted(os.listdir(root))
    if not candidates:
        raise FileNotFoundError(f"No checkpoint file found in {root!r}")

    checkpoint = candidates[0]
    print(checkpoint)
    # os.path.join avoids a doubled separator when root already ends in '/'.
    model = PadimLightning.load_from_checkpoint(os.path.join(root, checkpoint))
    trainer = Trainer(progress_bar_refresh_rate=0)

    return model, trainer

In [None]:
def sub_photo(host, port, topic):
    """Receive one image over MQTT and dump it to text files under ``./data``.

    Connects to the broker, subscribes to ``topic`` and blocks in the network
    loop. Each message is decoded as a JSON object whose first entry maps the
    channel index (key) to that channel's data (value) and whose second entry
    holds the timestamp as its value. The channel data is written to
    ``./data/image_<channel>.txt`` and the timestamp to
    ``./data/timestamp.txt``.

    The function returns once channels 0, 1 and 2 have all been received
    during this call. Waiting for all three (instead of stopping on channel 2
    alone) prevents returning with missing or stale channel files when the
    subscription starts in the middle of a transmission.

    Args:
        host: Broker host name or address.
        port: Broker TCP port.
        topic: Topic on which the image channels are published.

    Returns:
        None.
    """
    expected_channels = {0, 1, 2}
    received_channels = set()

    def on_connect(client, userdata, flags, rc):
        print('Connected to PHOTO '+str(rc))
        # Subscribing here means the subscription is renewed on reconnects.
        client.subscribe(topic)

    def on_message(client, userdata, msg):
        data = json.loads(msg.payload.decode("utf8"))

        # Entries are taken by position: [0] is channel -> data, [1] carries the timestamp.
        entries = list(data.items())
        channel_key, channel_data = entries[0]
        image = str(channel_data)
        c = int(channel_key)
        timestamp = str(entries[1][1])

        os.makedirs('./data', exist_ok=True)

        with open('./data/timestamp.txt', 'w') as f:
            f.write(timestamp)
        print('timestamp written')

        with open(f'./data/image_{c}.txt', 'w') as f:
            f.write(image)
        print('image written')

        received_channels.add(c)
        if expected_channels <= received_channels:
            # Disconnecting makes loop_forever() return to the caller.
            client.disconnect()

    client = mqtt.Client()
    client.on_connect = on_connect
    client.on_message = on_message

    client.connect(host, port, keepalive=180)
    client.loop_forever()

In [None]:
def pub_results(model, trainer,image, host, port, topic):
    """Run inference on a saved webcam image and publish the predicted label.

    Args:
        model: Model passed to ``trainer.predict``.
        trainer: Object exposing ``predict(model=..., dataloaders=...)``.
        image: File name (without extension) of the PNG under
            ``./data/webcam_images/``.
        host: MQTT broker host name or address.
        port: MQTT broker TCP port.
        topic: Topic on which the label is published.

    Returns:
        The first element of the prediction list returned by
        ``trainer.predict``. Its ``'pred_labels'`` entry is what gets
        published, as the string form of an int.
    """
    loader = DataLoader(InferenceDataset(f"./data/webcam_images/{image}.png", image_size=(288, 288)))
    output = trainer.predict(model=model, dataloaders=loader)[0]
    pred_label = str(int(output['pred_labels'].tolist()[0]))

    client = mqtt.Client()
    client.connect(host, port, keepalive=180)
    try:
        client.publish(topic, pred_label)
    finally:
        # A new client is created on every call; close it so repeated calls
        # from a long-running loop do not accumulate open connections.
        client.disconnect()
    return output

In [None]:
def text_to_img():
    """Rebuild the received image from the per-channel text files and save it.

    Reads ``./data/timestamp.txt`` and ``./data/image_{0,1,2}.txt``. Each
    channel file is parsed with ``ast.literal_eval`` and stored as one plane
    of a 288x288x3 array, which is then written to
    ``./data/webcam_images/webcam_<timestamp>.png``.

    Returns:
        A ``(image, timestamp)`` tuple: the assembled array and the timestamp
        string read from disk.

    Raises:
        OSError: If OpenCV reports that the PNG could not be written.
    """
    image = np.zeros((288,288,3))

    # Context managers close every handle; this runs once per received frame.
    with open('./data/timestamp.txt', "r") as f:
        timestamp = f.read()

    for channel in range(3):
        with open(f'./data/image_{channel}.txt', "r") as f:
            channel_values = np.asarray(ast.literal_eval(f.read()))
        image[:,:,channel] = channel_values

    out_dir = './data/webcam_images'
    # cv2.imwrite does not create missing directories; it just returns False.
    os.makedirs(out_dir, exist_ok=True)
    out_path = f'{out_dir}/webcam_{timestamp}.png'
    if not cv2.imwrite(out_path, image):
        raise OSError(f"cv2.imwrite failed to write {out_path!r}")
    return image, timestamp

In [None]:

# TCP port of the MQTT broker; handed to sub_photo / pub_results by the main loop.
port = 1883
# Load the PaDiM checkpoint and create the trainer once, so every prediction reuses them.
model, trainer = load_model('./results/padim/big_50/run/weights/')

In [None]:

# One-off inference on a local sample image, independent of the MQTT loop.
# NOTE(review): presumably a smoke test / warm-up of the model - confirm it is still needed.
dataset = DataLoader(
    InferenceDataset("./que.jpg", image_size=(288, 288))
)
output = trainer.predict(model=model, dataloaders=dataset)[0]
pred_label = str(int(output["pred_labels"].tolist()[0]))

In [None]:
# Address of the MQTT broker (hard-coded; adjust to the network in use).
ip = "192.168.137.70"
# Prediction dicts keyed by 'webcam_<timestamp>'; filled by the main loop, read by the show_* helpers.
outputs = {}

In [None]:
# Service loop: receive a frame, classify it, publish the label. Runs until the kernel is interrupted.
while True:
    # Blocks until sub_photo disconnects from the broker.
    sub_photo(ip, port, 'photo')
    print("Foto recibida")

    # Rebuild the frame from the channel text files and save it as a PNG.
    img, timestamp = text_to_img()

    # Predict on the saved PNG, publish the label and keep the full prediction for plotting.
    image_name = f'webcam_{timestamp}'
    output = pub_results(model, trainer, image_name, ip, port, 'results')
    outputs[image_name] = output
    print("Resultados publicados")

In [None]:
def show_image(path):
    """Return the input image of a stored prediction after ``Denormalize``.

    Args:
        path: Key into the module-level ``outputs`` dict
            (the main loop stores entries as ``'webcam_<timestamp>'``).

    Returns:
        The result of applying a fresh ``Denormalize()`` transform to the
        first element of the prediction's ``"image"`` entry.
    """
    prediction = outputs[path]
    denormalize = Denormalize()
    return denormalize(prediction["image"][0])

def show_anomaly_map(path):
    """Return the first anomaly map of a stored prediction as a NumPy array.

    Args:
        path: Key into the module-level ``outputs`` dict.

    Returns:
        The first element of the prediction's ``"anomaly_maps"`` entry, moved
        to CPU, converted to NumPy and squeezed of singleton dimensions.
    """
    first_map = outputs[path]["anomaly_maps"][0]
    as_array = first_map.cpu().numpy()
    return as_array.squeeze()
    
def show_pred_maks(path):
    """Return the first predicted mask of a stored prediction as a NumPy array.

    The function name keeps its original spelling because ``show_results``
    calls it under this name.

    Args:
        path: Key into the module-level ``outputs`` dict.

    Returns:
        The first element of the prediction's ``"pred_masks"`` entry, squeezed
        of singleton dimensions, moved to CPU and converted to NumPy.
    """
    first_mask = outputs[path]["pred_masks"][0]
    squeezed = first_mask.squeeze()
    return squeezed.cpu().numpy()

def show_results(path):
    """Plot the image, anomaly map and predicted mask of a stored prediction.

    Draws a single row of three panels, each titled and with its axes hidden,
    and then displays the figure.

    Args:
        path: Key into the module-level ``outputs`` dict.
    """
    # Gather the three arrays first, in the same order as the panels.
    panels = [
        ('Image', show_image(path)),
        ('Anomaly map', show_anomaly_map(path)),
        ('Predicted masks', show_pred_maks(path)),
    ]

    _, axes = plt.subplots(1, 3, figsize=(12, 4))

    for ax, (title, data) in zip(axes, panels):
        ax.imshow(data)
        ax.set_title(title)
        ax.axis('off')

    plt.show()