In [None]:
# Install MMCV
!pip install openmim
!mim install mmcv-full
!pip install git+https://github.com/open-mmlab/mmsegmentation.git

In [None]:
# Switch the working directory to the project folder; the relative config,
# checkpoint and image paths used in later cells are resolved from here.
# NOTE(review): the path lives under /content/drive — presumably Google Drive
# has to be mounted before this cell runs; confirm.
%cd /content/drive/MyDrive/thesis/test-beit

In [None]:
from mmseg.datasets.pipelines import Compose
from mmcv.parallel import collate, scatter
from mmseg.apis import init_segmentor

import torch.nn.functional as F
import matplotlib.pyplot as plt 
from scipy.stats import entropy
import numpy as np
import torch
import mmcv
import os

In [None]:
class LoadImage:
    """Pipeline step that reads the image referenced by ``results['img']``."""

    def __call__(self, results):
        """Load the image and record its metadata in ``results``.

        Args:
            results (dict): Must contain the key ``'img'``, holding either a
                file path (``str``) or another input that ``mmcv.imread``
                accepts.

        Returns:
            dict: The same ``results`` object. ``'img'`` is replaced by the
                loaded image, ``'img_shape'`` and ``'ori_shape'`` are set to
                its shape, and ``'filename'`` / ``'ori_filename'`` hold the
                path, or ``None`` when the input was not a string.
        """
        source = results['img']

        # Only a string input carries a file name worth recording.
        name = source if isinstance(source, str) else None
        results['filename'] = name
        results['ori_filename'] = name

        image = mmcv.imread(source)
        results.update(
            img=image,
            img_shape=image.shape,
            ori_shape=image.shape,
        )
        return results

In [None]:
# Model definition and matching pretrained weights, given relative to the
# current working directory.
config_file = (
    'configs/beit/'
    'upernet_beit-large_fp16_8x1_640x640_160k_ade20k.py'
)
checkpoint_file = (
    'checkpoints/beit/'
    'upernet_beit-large_fp16_8x1_640x640_160k_ade20k-8fc0dd5d.pth'
)

In [None]:
# Instantiate the segmentor from the config/checkpoint pair on device 'cuda:0'.
model = init_segmentor(
    config=config_file,
    checkpoint=checkpoint_file,
    device='cuda:0',
)

In [None]:
# Collect the paths of all files found directly inside each input folder.
directories = ['barrel']
imgs = []
for directory in directories:
  # os.listdir returns entries in arbitrary order; sort them so the batch
  # order is the same on every run.
  for filename in sorted(os.listdir(directory)):
    path = os.path.join(directory, filename)
    # Skip sub-directories (e.g. notebook checkpoint folders): only regular
    # files can be read as images later on.
    if os.path.isfile(path):
      imgs.append(path)

In [None]:
# Device name for the input data.
device = 'cuda'

# Run on a single image. This rebinds `imgs`, so any list assigned to it
# earlier is discarded ('barrel_3.jpg' was a previously used alternative).
imgs = '000001.jpg'

In [None]:
cfg = model.cfg

# Build the test-time pipeline. The first configured step is swapped for
# LoadImage (presumably it is the stock file-loading transform — confirm
# against the config).
test_pipeline = Compose([LoadImage()] + cfg.data.test.pipeline[1:])

# Normalise the input to a list; later cells rely on `imgs` being a list.
imgs = imgs if isinstance(imgs, list) else [imgs]
if not imgs:
    raise ValueError('No input images: `imgs` is empty.')

# Run every image through the pipeline, then collate all samples into one batch.
samples = [test_pipeline(dict(img=img)) for img in imgs]
data = collate(samples, samples_per_gpu=len(imgs))

# Move the batch to wherever the model actually lives, rather than to a
# separately configured device name that could drift out of sync with it.
model_device = next(model.parameters()).device
if model_device.type == 'cuda':
    # scatter to the model's GPU
    data = scatter(data, [model_device])[0]
else:
    # CPU path: unwrap the meta containers by hand.
    data['img_metas'] = [i.data[0] for i in data['img_metas']]

In [None]:
# Folder under which the masks are written.
save_dir = 'mask'

# Fraction of pixels discarded as uncertain: only pixels whose entropy lies
# below the (1 - alpha) percentile are kept.
alpha = 0.2

In [None]:
# Forward pass without gradient tracking, using the entries at index 0 of the
# collated image and meta lists.
with torch.no_grad():
  class_softmax = model.inference(data['img'][0], data['img_metas'][0], True)

# Most likely label per pixel, and the entropy along axis 1 of the model
# output (presumably the class axis of softmax probabilities — confirm).
output = torch.argmax(class_softmax, dim=1)
class_entropy = entropy(class_softmax.cpu(), axis=1)

# Keep (value 1) the pixels whose entropy is below the (1 - alpha) percentile.
threshold = np.percentile(class_entropy.ravel(), (1 - alpha) * 100)
refine_mask = (class_entropy[0] < threshold).astype(int)

# Rejected pixels become label 0; the final mask marks the pixels that are
# still labelled 111 (presumably the class of interest — confirm against the
# model's label set).
refine_output = output.cpu() * refine_mask
object_mask = np.where(refine_output == 111, 1, 0)

In [None]:
# Preview the first mask in grayscale.
first_mask = object_mask[0]
plt.imshow(first_mask, cmap='gray')

In [None]:
# Compute and save one object mask per input image.
#
# Index 0 of the collated lists holds the batch (the same entries the
# single-image cell passes to model.inference); the individual images are
# rows of that batch, so each one is sliced out and processed on its own.
batch_imgs = data['img'][0]
batch_metas = data['img_metas'][0]

for i in range(len(imgs)):
  img_meta = batch_metas[i]
  with torch.no_grad():
    # Slice with i:i + 1 to keep the batch dimension the model expects.
    class_softmax = model.inference(batch_imgs[i:i + 1], [img_meta], True)
  output = class_softmax.argmax(dim=1)
  class_entropy = entropy(class_softmax.cpu(), axis=1)

  # Keep only pixels whose entropy is below the (1 - alpha) percentile.
  threshold = np.percentile(class_entropy.flatten(), 100 * (1 - alpha))
  refine_mask = np.where(class_entropy[0] < threshold, 1, 0)

  # Convert to NumPy explicitly so the arithmetic does not depend on implicit
  # tensor/array mixing.
  refine_output = output.cpu().numpy() * refine_mask
  # Label 111 is presumably the class of interest — confirm against the
  # model's label set.
  object_mask = np.where(refine_output == 111, 1, 0)

  # Nothing of the target class survived the filtering: skip this image.
  if not object_mask.any():
    continue

  save_path = os.path.join(save_dir, img_meta['filename'])
  # Store the mask as an 8-bit image (0 = background, 255 = object).
  # mmcv.imwrite takes the image first and creates missing parent folders.
  mmcv.imwrite((object_mask[0] * 255).astype(np.uint8), save_path)