In [1]:
from google.colab import drive
import numpy as np
import cv2

# Attach Google Drive so the project files under /content/drive are reachable.
drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [2]:
from pathlib import Path
import os

from PIL import Image

In [3]:
# Root of the project on the mounted Drive. The trailing slash matters: later
# cells build paths by plain string concatenation onto this value.
PROJECT_DIR = "/content/drive/MyDrive/CS 7150/project/"

## Image Segmentation

### Build environment

In [None]:
! git clone https://github.com/microsoft/unilm.git
! cd unilm/beit2 ; pip install -r requirements.txt

! pip install openmim
! mim install mmcv-full==1.3.0
! pip install scipy timm==0.3.2 mmsegmentation==0.11.0

In [5]:
# * Download model weights
# !wget https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21ktoade20k.pth
ckpt_path = f'"{PROJECT_DIR}beitv2_large_patch16_224_pt1k_ft21ktoade20k.pth"'
! cp $ckpt_path beitv2_large_patch16_224_pt1k_ft21ktoade20k.pth

### Build model

In [6]:
import sys
# The cloned segmentation code is not an installed package, so its directory
# must be on sys.path *before* the import below can resolve `backbone`.
sys.path.append('unilm/beit2/semantic_segmentation')
# NOTE(review): `beit` is never referenced by name in this notebook; the import
# is presumably needed only for its side effect of registering the BEiT
# backbone with mmseg — confirm before removing.
from backbone import beit

apex is not installed
apex is not installed
apex is not installed
apex is not installed


In [7]:
from mmseg.apis import inference_segmentor, init_segmentor, show_result_pyplot

In [8]:
# Checkpoint copied into the working directory by the earlier `cp` cell.
checkpoint_file = 'beitv2_large_patch16_224_pt1k_ft21ktoade20k.pth'
# Segmentation config shipped with the cloned repo (UperNet head on a BEiT-large
# backbone, going by the file name).
config_file = (
    'unilm/beit2/semantic_segmentation/configs/beit/upernet/'
    'upernet_beit_large_24_512_slide_160k_21ktoade20k.py'
)
# Build the segmentor and load the weights onto the first CUDA device.
model = init_segmentor(config=config_file, checkpoint=checkpoint_file, device='cuda:0')

Use load_from_local loader


### Model inference

In [None]:
from skimage import io
from skimage import color

In [52]:
# Sample images live under <project>/sample_imgs, split into `input` and
# `target` sub-directories; collect the file names found in each.
test_dir = Path(PROJECT_DIR) / 'sample_imgs'

input_imgs, target_imgs = (
    os.listdir(test_dir / subdir) for subdir in ('input', 'target')
)

In [54]:
# Segment every sample image, save a label overlay next to the samples, and
# display mmseg's own visualisation.
#
# Each tuple is (source sub-directory, output sub-directory, file names).
for src_subdir, out_subdir, img_names in (
    ('input', 'seg_in', input_imgs),
    ('target', 'seg_tar', target_imgs),
):
    out_dir = test_dir / out_subdir
    # cv2.imwrite does not create missing directories; it just returns False.
    out_dir.mkdir(parents=True, exist_ok=True)

    for img in img_names:
        img_path = str(test_dir / src_subdir / img)
        result = inference_segmentor(model, img_path)

        # label2rgb returns a float RGB image with values in [0, 1].
        seg_map = color.label2rgb(result[0], io.imread(img_path))

        # cv2.imwrite expects 8-bit BGR data. Passing the float image directly
        # gets it saturate-cast to 0/1 (a near-black file) with red and blue
        # swapped, so scale to uint8 and convert the channel order first.
        seg_map_u8 = (np.clip(seg_map, 0.0, 1.0) * 255).round().astype(np.uint8)
        seg_map_bgr = cv2.cvtColor(seg_map_u8, cv2.COLOR_RGB2BGR)

        out_path = str(out_dir / f'seg_{img}')
        if not cv2.imwrite(out_path, seg_map_bgr):
            raise IOError(f'Could not write segmentation overlay to {out_path}')

        show_result_pyplot(model, img_path, result)

Output hidden; open in https://colab.research.google.com to view.

## Color Transfer

### Build environment

In [None]:
code_path = f'"{PROJECT_DIR}origin_code"'
! cd $code_path ; pip install -r requirements.txt

### Model inference

In [55]:
# `code_path` is wrapped in literal double quotes for shell use; drop the first
# and last character to get the bare directory, then make it the working
# directory so the relative paths passed to test.py resolve from there.
color_transfer_dir = code_path[1:-1]
os.chdir(color_transfer_dir)

In [56]:
! python test.py --dataroot ../sample_imgs --checkpoints_dir ../color_transfer_model --results_dir ../sample_imgs/result_sr --is_SR
! python test.py --dataroot ../sample_imgs --checkpoints_dir ../color_transfer_model --results_dir ../sample_imgs/result

  init.normal(m.weight.data, 0.0, 0.02)
  init.normal(m.weight.data, 0.0, 0.02)
  "See the documentation of nn.Upsample for details.".format(mode))

image:  0 / 5
0000: process image... 

image:  1 / 5
0001: process image... 

image:  2 / 5
0002: process image... 

image:  3 / 5
0003: process image... 

image:  4 / 5
0004: process image... 
  init.normal(m.weight.data, 0.0, 0.02)
  init.normal(m.weight.data, 0.0, 0.02)
  "See the documentation of nn.Upsample for details.".format(mode))

image:  0 / 5
0000: process image... 

image:  1 / 5
0001: process image... 

image:  2 / 5
0002: process image... 

image:  3 / 5
0003: process image... 

image:  4 / 5
0004: process image... 


### Visualization

In [57]:
import matplotlib.pyplot as plt
from PIL import Image

In [58]:
# Output directories of the two test.py runs (without / with --is_SR).
result_dir = PROJECT_DIR + 'sample_imgs/result/'
result_sr_dir = PROJECT_DIR + 'sample_imgs/result_sr/'

# os.listdir() returns entries in arbitrary order, but the plotting code indexes
# `results` positionally in groups of three per experiment, so sort the listing
# to make that grouping deterministic across runs and file systems.
# NOTE(review): assumes the lexicographic order of the file names written by
# test.py matches the input / target / output order the plot expects — confirm.
results = sorted(os.listdir(result_dir))

In [59]:
# Grid layout: one row per experiment, four panels per row. `results` is
# expected to hold three files per experiment, read in the order of the first
# three column titles below.
col_map = {0:'Input',1:'Target',2:'Without seg',3:'With seg'}
num_cols    = 4
num_rows    = len(results) // 3
num_images  = num_rows * num_cols

fig, axes = plt.subplots(num_rows, num_cols, figsize=(20,20))
list_axes = list(axes.flat)

for i, ax in enumerate(list_axes[:num_images]):
    exp_no, col_no = divmod(i, num_cols)
    if col_no == 3:
        # Fourth panel: same file name as the third panel, but read from the
        # result_sr directory.
        img = Image.open(result_sr_dir + results[exp_no*3 + 2])
    else:
        img = Image.open(result_dir + results[exp_no*3 + col_no])

    ax.imshow(img)
    # Column titles go on the first row only.
    if exp_no == 0:
        ax.set_title(col_map[col_no], fontsize=25)

# Hide any axes that did not receive an image.
for ax in list_axes[num_images:]:
    ax.set_visible(False)

fig.tight_layout()
_ = plt.show()

Output hidden; open in https://colab.research.google.com to view.