# Inference with a fine-tuned OneFormer model on the test dataset

In [None]:
from PIL import Image
from transformers import AutoProcessor
from transformers import AutoModelForUniversalSegmentation
import evaluate
import torch
import os
import numpy as np
import pandas as pd
from glob import glob

In [None]:
# images_folder = r"C:\Users\lliu\Desktop\FrontierSI\projects\GA_floor_height\GA-floor-height\output\Wagga\GSV_annotations_converted_merged\all\images"
# model_folder = r"C:\Users\lliu\Desktop\FrontierSI\projects\GA_floor_height\GA-floor-height\output\oneformer\from_all"
# out_folder=r'C:\Users\lliu\Desktop\FrontierSI\projects\GA_floor_height\GA-floor-height\output\Wagga\GSV_prediction\OneFormer\from_all'

In [None]:
# Locations used by this run (Windows paths; raw strings keep the backslashes
# literal).
images_folder = r"D:\Wagga\RICS\all_images"  # folder searched for input images
model_folder = r"D:\Wagga\RICS\OneFormer\from_all"  # processor and model are loaded from here
out_folder = r"D:\Wagga\RICS\all_images_predicted"  # predicted label maps are written here

### Read in GSV image

In [None]:
# Collect every JPEG in the input folder (swap the pattern for "*.png" when the
# inputs are PNG files). glob() returns matches in arbitrary, OS-dependent
# order, so the list is sorted to keep the processing order -- and
# image_files[0] -- reproducible between runs.
image_files = sorted(glob(os.path.join(images_folder, "*.jpg")))

In [None]:
# image_file = image_files[0]
image_file = r'D:\Wagga\RICS\all_images\Industrial_131112_04706_L_0007862.jpg'
# Image.open() is lazy and keeps the underlying file open. Loading inside a
# context manager releases the handle, and convert("RGB") guarantees a
# 3-channel image even when the JPEG is grayscale, CMYK or palette based.
with Image.open(image_file) as source_image:
    image = source_image.convert("RGB")

### Load pre-trained model and initialise processor

In [None]:
# Alternative label scheme, kept for reference:
# id2label =  {1:'front door',2:'foundation',3:'garage door',4:'pavement'}

# Class names listed in class-id order; position in the tuple is the class id.
_class_names = ("_background_", "foundation", "front door", "garage door", "stairs")

# Forward (id -> name) and reverse (name -> id) lookups handed to the model.
id2label = dict(enumerate(_class_names))
label2id = dict(zip(id2label.values(), id2label.keys()))

In [None]:
# Load the processor saved alongside the fine-tuned model.
processor = AutoProcessor.from_pretrained(model_folder)

# Encode the image for the "semantic" task as PyTorch tensors. Resizing is
# switched off here; the fixed-size alternative was:
# encoded_inputs = processor(images=image, task_inputs=["semantic"], return_tensors="pt",size=(512,512))
encoded_inputs = processor(
    images=image,
    task_inputs=["semantic"],
    return_tensors="pt",
    do_resize=False,
)
# To inspect the encoded task prompt:
# processor.tokenizer.batch_decode(encoded_inputs.task_inputs)

In [None]:
# Inference is pinned to the CPU. GPU auto-detection is kept for reference;
# NOTE(review): if it is enabled, the processor outputs must be moved to the
# same device before calling the model.
# device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device = 'cpu'

# Load the fine-tuned weights, passing the label mappings explicitly so the
# model config carries this project's class names and class count.
model = AutoModelForUniversalSegmentation.from_pretrained(
    model_folder,
    is_training=False,
    ignore_mismatched_sizes=True,
    num_labels=len(label2id),
    id2label=id2label,
    label2id=label2id,
)
# model = AutoModelForUniversalSegmentation.from_pretrained(model_path)

# Place the model on the chosen device and switch it to evaluation mode.
model.to(device)
model.eval()

### Inference

In [None]:
# Forward pass. The model lives on `device`, so the encoded inputs must be on
# the same device; this is a no-op while device is 'cpu' and avoids a device
# mismatch error once a GPU is selected.
encoded_inputs = encoded_inputs.to(device)

# Gradients are not needed for inference; disabling them saves memory.
with torch.no_grad():
    outputs = model(**encoded_inputs)

### Post process

In [None]:
# PIL reports size as (width, height); the post-processor is handed the
# reversed pair, i.e. (height, width), as the target size of the output map.
predicted_segmentation_map = processor.post_process_semantic_segmentation(
    outputs,
    target_sizes=[image.size[::-1]],
)[0]
# Show the shape of the resulting map as the cell output.
predicted_segmentation_map.shape

### Save predictions

In [None]:
# Turn the predicted class-id map into an 8-bit image (one class id per pixel).
# .cpu() makes this work regardless of the device the prediction lives on.
prediction_arr = Image.fromarray(predicted_segmentation_map.cpu().numpy().astype(np.uint8))

# Save the prediction as a PNG (lossless, so class ids are stored exactly).
# Only the file extension is swapped: str.replace('jpg', 'png') would also
# rewrite "jpg" inside the file stem and would leave ".JPG"/".jpeg" untouched,
# which would make PIL write a lossy JPEG instead.
os.makedirs(out_folder, exist_ok=True)  # save() fails if the folder is missing
out_prediction = os.path.join(out_folder, os.path.splitext(os.path.basename(image_file))[0] + '.png')
prediction_arr.save(out_prediction)

## Put it all together and run inference for all validation images

In [None]:
# Not populated by this loop; kept because other cells may reference them.
predictions = []
gt_labels = []

# Make sure the destination exists before the first prediction is written.
os.makedirs(out_folder, exist_ok=True)

for image_file in image_files:
    # Output name = input stem + ".png". Only the extension is swapped:
    # str.replace('jpg', 'png') would also rewrite "jpg" inside the stem and
    # would leave ".JPG"/".jpeg" untouched (PIL would then write a lossy JPEG).
    out_prediction = os.path.join(out_folder, os.path.splitext(os.path.basename(image_file))[0] + '.png')

    # Resume support: skip images that already have a prediction on disk.
    if os.path.exists(out_prediction):
        print('prediction exists, skipping...')
        continue

    # Image.open() is lazy and keeps the file open; the context manager
    # releases the handle for every image, and convert("RGB") guarantees a
    # 3-channel input even for grayscale, CMYK or palette JPEGs.
    with Image.open(image_file) as source_image:
        image = source_image.convert("RGB")

    # Encode for the "semantic" task without resizing, then move the tensors
    # to the model's device (a no-op while device is 'cpu').
    # encoded_inputs = processor(images=image, task_inputs=["semantic"], return_tensors="pt",size=(512,512))
    encoded_inputs = processor(
        images=image,
        task_inputs=["semantic"],
        return_tensors="pt",
        do_resize=False,
    ).to(device)
    # processor.tokenizer.batch_decode(encoded_inputs.task_inputs)

    # Forward pass; gradients are not needed for inference.
    with torch.no_grad():
        outputs = model(**encoded_inputs)

    # PIL size is (width, height); the post-processor is given (height, width).
    predicted_segmentation_map = processor.post_process_semantic_segmentation(
        outputs,
        target_sizes=[image.size[::-1]],
    )[0]

    # Save the class-id map as an 8-bit PNG; compress_level=1 favours write
    # speed over file size.
    image_predicted = Image.fromarray(predicted_segmentation_map.cpu().numpy().astype(np.uint8))
    image_predicted.save(out_prediction, compress_level=1)
