In [1]:
import torch
import torch.nn as nn
import os
from PIL import Image
import PIL
import ttach as tta
import yaml
import albumentations as A
import numpy as np
import json
from albumentations.pytorch import ToTensorV2
# --- Runtime configuration -------------------------------------------------
# All models in this notebook are moved to this device in half precision.
device = torch.device("cuda:0")
config_path = "predict_config.yaml"

# NOTE(review): yaml.load with FullLoader is kept unchanged; if the config file
# can come from an untrusted source, consider yaml.safe_load instead.
with open(config_path) as f:
    predict_hyps = yaml.load(f, Loader=yaml.FullLoader)
path_to_images = predict_hyps["path_to_images"]

# Detector class id -> class name and per-class confidence threshold.
detection_hyps = {
    0: {"class": "Human", "conf_tr": 0.57},
    1: {"class": "Car", "conf_tr": 0.63},
    2: {"class": "Wagon", "conf_tr": 0.42},
    3: {"class": "Signal", "conf_tr": 0.44},
    4: {"class": "FacingSwitch", "conf_tr": 0.7},
    5: {"class": "TrailingSwitch", "conf_tr": 0.55},
}

# NOTE(review): torch.load unpickles the checkpoint file - only point the
# config at checkpoints you trust.
segmentation_model = torch.load(predict_hyps["segmentation"]).eval().half().to(device)


def _build_transform(height, width):
    """Return a Resize -> Normalize -> ToTensorV2 pipeline for one input size.

    The normalisation mean/std are read from the loaded config
    (keys mean0..mean2 and std0..std2); fresh lists are built on every call.
    """
    channel_mean = [predict_hyps[key] for key in ("mean0", "mean1", "mean2")]
    channel_std = [predict_hyps[key] for key in ("std0", "std1", "std2")]
    return A.Compose([
        A.Resize(height=height, width=width),
        A.Normalize(channel_mean, channel_std),
        ToTensorV2(),
    ])


# Segmentation input size; the TTA below resizes to three scales and maps the
# predictions back to this original size.
transform_seg = _build_transform(1536, 2688)
transforms_tta_seg = tta.Compose(
    [
        tta.Resize(
            sizes=[(640, 1280), (960, 1920), (1536, 2688)],
            original_size=(1536, 2688),
            interpolation="nearest",
        )
    ]
)

# Crop classifiers: switches use 128x256 inputs, signals use 128x128.
transform_cls_switches = _build_transform(128, 256)
transform_cls_signal = _build_transform(128, 128)

In [None]:
# Segmentation pass: for every file in the input folder, run the segmentation
# model with test-time augmentation and save a 3-channel label image under the
# same file name in the mask output folder.
images_dir = predict_hyps["path_to_images"]

# Lookup table from argmax channel index to the label value written to disk:
# index 0 stays 0, indices 1, 2, 3 become 6, 7, 10.
index_to_label = np.array([0, 6, 7, 10], dtype=np.uint8)

for image_file in os.listdir(images_dir):
    pil_image = Image.open(os.path.join(images_dir, image_file)).convert("RGB")
    orig_size = pil_image.size  # PIL order: (width, height)

    # Reverse the channel axis (RGB -> BGR) before the albumentations pipeline.
    bgr_array = np.array(pil_image)[:, :, ::-1]
    image = transform_seg(image=bgr_array)["image"].unsqueeze(0)

    # Accumulator for the de-augmented model outputs of every TTA variant.
    final_mask = torch.zeros((1, 4, 1536, 2688)).to(device)
    with torch.no_grad():
        for transformer in transforms_tta_seg:
            augmented_image = transformer.augment_image(image).to(device)
            model_output = segmentation_model(augmented_image.half())
            final_mask += transformer.deaugment_mask(model_output.float())

    # The divisor 3 presumably matches the three TTA scales configured in the
    # setup cell - keep them in sync if the scale list changes.
    class_probs = torch.softmax(final_mask / 3, dim=1)
    class_index = torch.argmax(class_probs, dim=1).squeeze(0).detach().cpu().numpy()

    # Map indices to label values and replicate into three identical channels (H, W, 3).
    label_map = index_to_label[class_index]
    rgb_mask = np.stack([label_map] * 3, axis=-1)

    # Nearest-neighbour resize back to the source resolution keeps label values intact.
    mask_image = Image.fromarray(rgb_mask).resize(orig_size, PIL.Image.NEAREST)
    mask_image.save(os.path.join(predict_hyps["mask_save_path"], image_file))

# Drop the reference to the segmentation model once the pass is finished.
del segmentation_model

In [None]:
# Run the YOLOv5 detection script on the input image folder with test-time
# augmentation (--augment) and a global confidence threshold of 0.4.
# `$path_to_images` is interpolated by IPython from the Python variable of the
# same name; it is not quoted, so a path containing spaces would be split.
# NOTE(review): the following cell reads
# ./detector_predictions/predictions/labels/<image stem>.txt and unpacks six
# space-separated fields per line (class x y w h confidence). With the stock
# YOLOv5 detect.py that presumably requires --save-txt --save-conf and
# --name predictions, none of which are passed here - confirm that the local
# detect.py defaults provide them.
!python3 ../detection/yolov5/detect.py --weights ../models_checkpoints/detection_inference.pt --source $path_to_images --conf-thres 0.4 --augment  --project=./detector_predictions

In [None]:
# Second stage: refine detector classes with crop classifiers and export one
# JSON file of bounding boxes per input image.
# NOTE(review): torch.load unpickles the checkpoint files - only point the
# config at checkpoints you trust.
classification_signal = torch.load(predict_hyps["signal"]).eval().half().to(device)
classification_t_switch = torch.load(predict_hyps["tswitch"]).eval().half().to(device)
classification_f_switch = torch.load(predict_hyps["fswitch"]).eval().half().to(device)

# Detector class name -> classifier model, its preprocessing transform and the
# mapping from classifier output index to the final class name.
classificators_dict = {
    "Signal": {
        "model": classification_signal,
        "tr": transform_cls_signal,
        0: "SignalE",
        1: "SignalF",
    },
    "FacingSwitch": {
        "model": classification_f_switch,
        "tr": transform_cls_switches,
        0: "FacingSwitchR",
        1: "FacingSwitchL",
        2: "FacingSwitchNV",
    },
    "TrailingSwitch": {
        "model": classification_t_switch,
        "tr": transform_cls_switches,
        0: "TrailingSwitchR",
        1: "TrailingSwitchL",
        2: "TrailingSwitchNV",
    },
}

# Folder the detector label files are read from.
yolo_labels_dir = "./detector_predictions/predictions/labels"


def _to_pixel_box(x, y, w, h, img_w, img_h):
    """Convert a normalised centre/size box to integer pixel corners (x1, y1, x2, y2)."""
    x1 = int((x - w / 2) * img_w)
    y1 = int((y - h / 2) * img_h)
    x2 = int((x + w / 2) * img_w)
    y2 = int((y + h / 2) * img_h)
    return x1, y1, x2, y2


def _refine_class(img, box, class_name):
    """Crop `box` from `img` and return the sub-class name chosen by the classifier for `class_name`."""
    entry = classificators_dict[class_name]
    # Reverse the channel axis (RGB -> BGR) before the albumentations pipeline.
    crop_bgr = np.array(img.crop(box))[:, :, ::-1]
    tensor_img = entry["tr"](image=crop_bgr)["image"].unsqueeze(0)
    with torch.no_grad():
        out = entry["model"](tensor_img.half().to(device))
    # Batch size is 1, so the argmax holds exactly one index.
    sub_idx = int(torch.argmax(torch.softmax(out, dim=1), dim=1).item())
    return entry[sub_idx]


for image_file in os.listdir(predict_hyps["path_to_images"]):
    img = Image.open(os.path.join(predict_hyps["path_to_images"], image_file)).convert("RGB")
    orig_size = img.size  # PIL order: (width, height)
    bb_dict = {
        "img_size": {"height": orig_size[1], "width": orig_size[0]},
        "bb_objects": [],
    }

    # Use splitext to drop the extension: str.strip(".png") removes any of the
    # characters '.', 'p', 'n', 'g' from BOTH ends (e.g. "pic.png" -> "ic"),
    # which made the label lookup miss files.
    label_path = os.path.join(yolo_labels_dir, os.path.splitext(image_file)[0] + ".txt")

    # A missing label file means no detections: the JSON is still written with
    # an empty "bb_objects" list.
    if os.path.exists(label_path):
        with open(label_path, "r") as yolo_pred:
            for line in yolo_pred:
                fields = line.split()
                if not fields:
                    # Tolerate blank lines instead of failing on the unpack below.
                    continue
                cls, x, y, w, h, prob = fields
                cls_id = int(cls)
                if float(prob) <= detection_hyps[cls_id]["conf_tr"]:
                    # Below the per-class confidence threshold.
                    continue

                box = _to_pixel_box(
                    float(x), float(y), float(w), float(h), orig_size[0], orig_size[1]
                )
                str_cls = detection_hyps[cls_id]["class"]
                if str_cls in classificators_dict:
                    # Signals and switches get a finer class from a dedicated classifier.
                    str_cls = _refine_class(img, box, str_cls)

                x1, y1, x2, y2 = box
                bb_dict["bb_objects"].append(
                    {"x1": x1, "y1": y1, "x2": x2, "y2": y2, "class": str_cls}
                )

    # Open the output only once the record is complete, so a failure while
    # processing an image does not leave an empty JSON file behind.
    json_path = os.path.join(predict_hyps["json_save_path"], image_file) + ".json"
    with open(json_path, "w") as bb_predictions_json:
        json.dump(bb_dict, bb_predictions_json)
                