# Inference fine-tuned OneFormer on test dataset

In [13]:
from PIL import Image
from transformers import AutoProcessor
from transformers import AutoModelForUniversalSegmentation
import evaluate
import torch
import os
import numpy as np
import pandas as pd
from glob import glob

In [14]:
images_folder = r"D:\Launceston\GSV\Pano_clipped"
model_folder = r"D:\Wagga\RICS\OneFormer\from_all"
out_folder=r'D:\Launceston\GSV\Panos_clipped_predicted'

In [15]:
# images_folder = r"D:\Wagga\RICS\all_images"
# model_folder = r"D:\Wagga\RICS\OneFormer\from_all"
# out_folder=r'D:\Wagga\RICS\all_images_predicted'

### Read in GSV image

In [16]:
# image_files = sorted(glob(f"{images_folder}/*.png"))
image_files = glob(f"{images_folder}/*.jpg")

In [17]:
image_file=image_files[0]
image = Image.open(image_file)

### Load pre-trained model and initialise processor

In [18]:
# id2label =  {1:'front door',2:'foundation',3:'garage door',4:'pavement'}
id2label = {0:"_background_", 1:"foundation", 2:"front door", 3:"garage door", 4:"stairs"}
id2label = {int(k): v for k, v in id2label.items()}
label2id = {v: k for k, v in id2label.items()}

In [19]:
processor = AutoProcessor.from_pretrained(model_folder)
# encoded_inputs = processor(images=image, task_inputs=["semantic"], return_tensors="pt",size=(512,512))
encoded_inputs = processor(images=image, task_inputs=["semantic"], return_tensors="pt",do_resize=False)
# processor.tokenizer.batch_decode(encoded_inputs.task_inputs)

Some kwargs in processor config are unused and will not have any effect: task_seq_length, max_seq_length. 


In [20]:
# device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 
device='cpu'
model = AutoModelForUniversalSegmentation.from_pretrained(model_folder,is_training=False,
                                                        ignore_mismatched_sizes=True,
                                                        num_labels=len(label2id), 
                                                        id2label=id2label, 
                                                        label2id=label2id)
# model = AutoModelForUniversalSegmentation.from_pretrained(model_path)
model.to(device)
model.eval()

Some weights of the model checkpoint at D:\Wagga\RICS\OneFormer\from_all were not used when initializing OneFormerForUniversalSegmentation: ['model.text_mapper.prompt_ctx.weight', 'model.text_mapper.text_encoder.ln_final.bias', 'model.text_mapper.text_encoder.ln_final.weight', 'model.text_mapper.text_encoder.positional_embedding', 'model.text_mapper.text_encoder.token_embedding.weight', 'model.text_mapper.text_encoder.transformer.layers.0.layer_norm1.bias', 'model.text_mapper.text_encoder.transformer.layers.0.layer_norm1.weight', 'model.text_mapper.text_encoder.transformer.layers.0.layer_norm2.bias', 'model.text_mapper.text_encoder.transformer.layers.0.layer_norm2.weight', 'model.text_mapper.text_encoder.transformer.layers.0.mlp.fc1.bias', 'model.text_mapper.text_encoder.transformer.layers.0.mlp.fc1.weight', 'model.text_mapper.text_encoder.transformer.layers.0.mlp.fc2.bias', 'model.text_mapper.text_encoder.transformer.layers.0.mlp.fc2.weight', 'model.text_mapper.text_encoder.transforme

OneFormerForUniversalSegmentation(
  (model): OneFormerModel(
    (pixel_level_module): OneFormerPixelLevelModule(
      (encoder): SwinBackbone(
        (embeddings): SwinEmbeddings(
          (patch_embeddings): SwinPatchEmbeddings(
            (projection): Conv2d(3, 192, kernel_size=(4, 4), stride=(4, 4))
          )
          (norm): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
        )
        (encoder): SwinEncoder(
          (layers): ModuleList(
            (0): SwinStage(
              (blocks): ModuleList(
                (0): SwinLayer(
                  (layernorm_before): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
                  (attention): SwinAttention(
                    (self): SwinSelfAttention(
                      (query): Linear(in_features=192, out_features=192, bias=True)
                      (key): Linear(in_features=192, out_features=192, bias=True)
                      (value): Li

### Inference

In [21]:
# forward pass
with torch.no_grad():
  outputs = model(**encoded_inputs)

### Post process

In [22]:
predicted_segmentation_map = processor.post_process_semantic_segmentation(outputs, target_sizes=[(image.size[1],image.size[0])])[0]
predicted_segmentation_map.shape

torch.Size([2867, 3640])

### Save predictions

In [23]:
# image = Image.fromarray(np.array(semantic_segmentation).astype(np.uint8))
prediction_arr = Image.fromarray(np.array(predicted_segmentation_map).astype(np.uint8))

# Save the image as a png
out_prediction = os.path.join(out_folder,os.path.basename(image_file).replace('jpg','png'))
prediction_arr.save(out_prediction)

## Put together and do for all validation images

In [None]:
predictions=[]
gt_labels=[]
for image_file in image_files:
    out_prediction = os.path.join(out_folder,os.path.basename(image_file).replace('jpg','png'))
    if os.path.exists(out_prediction):
        print('prediction exists, skipping...')
    else:
        image = Image.open(image_file)
        # encoded_inputs = processor(images=image, task_inputs=["semantic"], return_tensors="pt",size=(512,512))
        encoded_inputs = processor(images=image, task_inputs=["semantic"], return_tensors="pt",do_resize=False)
        # processor.tokenizer.batch_decode(encoded_inputs.task_inputs)
        # forward pass
        with torch.no_grad():
            outputs = model(**encoded_inputs)
        predicted_segmentation_map = processor.post_process_semantic_segmentation(outputs, target_sizes=[(image.size[1],image.size[0])])[0]

        # Save prediction image as a JPG
        image_predicted = Image.fromarray(np.array(predicted_segmentation_map).astype(np.uint8))
        image_predicted.save(out_prediction,compress_level=1)


prediction exists, skipping...
prediction exists, skipping...
prediction exists, skipping...
prediction exists, skipping...
prediction exists, skipping...
prediction exists, skipping...
prediction exists, skipping...
prediction exists, skipping...
prediction exists, skipping...
prediction exists, skipping...
prediction exists, skipping...
prediction exists, skipping...
prediction exists, skipping...
prediction exists, skipping...
prediction exists, skipping...
prediction exists, skipping...
prediction exists, skipping...
prediction exists, skipping...
prediction exists, skipping...
prediction exists, skipping...
prediction exists, skipping...
prediction exists, skipping...
prediction exists, skipping...
prediction exists, skipping...
prediction exists, skipping...
prediction exists, skipping...
prediction exists, skipping...
prediction exists, skipping...
prediction exists, skipping...
prediction exists, skipping...
prediction exists, skipping...
prediction exists, skipping...
predicti