In [None]:
from google.colab import drive
# Mount Google Drive at /content/gdrive; the project directory used below
# lives under this mount point.
drive.mount('/content/gdrive')

import os
import sys
import torch
import numpy as np

# Project root on the mounted Drive. It is put first on sys.path so the
# project-local modules imported below resolve from here, and it becomes the
# current working directory.
working_directory = "/content/gdrive/MyDrive/WNet_project/WNet_R"
project_root = os.path.abspath(working_directory)
sys.path.insert(0, project_root)
os.chdir(working_directory)

#!pip install crfseg 

from WNet import WNet
from configure import Config
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
from matplotlib import pyplot as plt
from crfseg import CRF
from seg_plot import decode_segmap



# Load a single BSDS500 image, using the paths and sizes from the "test"
# configuration.
config = Config("test")
image_name = "8068.jpg"
img_loc = os.path.join(config.datapath, image_name)
image = Image.open(img_loc).convert("RGB")

# Preprocessing: resize to the configured input size, then convert the PIL
# image to a tensor.
preprocessing_steps = [
    transforms.Resize(config.inputsize),
    transforms.ToTensor(),
]
transform = transforms.Compose(preprocessing_steps)
tensor_image = transform(image)

# Show the preprocessed input. The axes are permuted to put the channel axis
# last, which is the layout imshow expects.
plt.imshow(tensor_image.permute(1, 2, 0))
plt.axis('off')
plt.show()

# Build the W-Net from the configuration and restore the trained weights.
# map_location forces the checkpoint tensors onto the CPU.
model = WNet(config.K, config.dropout, config.ch_mul, config.in_chans, config.out_chans)
checkpoint_path = os.path.join(config.saving_path, "21-04-15_19-36-39.pt")
model.load_state_dict(torch.load(checkpoint_path, map_location=torch.device('cpu')))
model.eval()  # switch the module to evaluation mode before inference

# Encoder forward pass on a batch of one image, followed by a softmax over
# dim 1. This is inference only, so autograd is disabled to avoid building a
# computation graph.
with torch.no_grad():
    enc = model(torch.unsqueeze(tensor_image, 0), returns='enc')
    soft_enc = F.softmax(enc, dim=1)

# Encoder plot: colour-coded segmentation produced by decode_segmap.
# NOTE(review): nc=21 is hard-coded here; confirm it should not follow config.K.
rgb = decode_segmap(soft_enc, nc=21)
plt.imshow(rgb)
plt.axis('off')
plt.show()

# Debug: index of the largest softmax value at each position.
print(torch.argmax(soft_enc.squeeze(), dim=0))

# Decoder plot (reconstructed image).
# Inference only, so autograd is disabled for the forward pass.
with torch.no_grad():
    dec = model(torch.unsqueeze(tensor_image, 0), returns='dec')

# Drop the batch dimension and move the channel axis last for plotting.
reconstruction = dec.squeeze().permute(1, 2, 0)
print(reconstruction)

plt.imshow(reconstruction.detach().cpu().numpy())
plt.axis('off')
# The title must be set before show(); setting it afterwards would leave the
# displayed figure untitled and attach the title to a new, empty figure.
plt.title('Reconstructed Image')
plt.show()

# CRF post-processing: refine the soft segmentation by passing it through the
# CRF module config.CRF_num times.
# NOTE(review): the CRF is fed softmax probabilities here; crfseg describes its
# input as logits / unary potentials. Confirm this is intended.
CRF_model = CRF(n_spatial_dims=2)
# Inference only: without no_grad every iteration would extend the autograd
# graph through the CRF's parameters.
with torch.no_grad():
    for _ in range(config.CRF_num):
        soft_enc = CRF_model(soft_enc)
rgb = decode_segmap(soft_enc, nc=21)

plt.imshow(rgb)
plt.axis('off')
plt.title('CRF-postprocessed segmentation')
plt.show()