In [1]:
import torch
import torchvision
from models import DeepLabV3
from datasets.cityscapes import CityScapes
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from torchvision import transforms as T
import torchvision.transforms.functional as F
from torchvision.utils import make_grid
import torch.nn as nn
import numpy as np
from torch.utils.tensorboard import SummaryWriter
import time
import os
from tqdm.notebook import tqdm
from utils import colorize_mask
from PIL import Image

In [3]:

# Data
# Raw string: the Windows path contains backslash sequences (\C, \G, \d) that
# are invalid escapes in a normal string literal.
root_dir = r"F:\COMP90055\GMIDA\datas\CityScapes"
dataloader = CityScapes(root_dir, batch_size=1, split='val', shuffle=True)

# Model: use the first GPU when available, otherwise fall back to CPU so the
# notebook still runs on a machine without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

weight = r"F:\COMP90055\GMIDA\train-runs\2022-09-22_22-44-38\checkpoint\deeplabv3-0.pth"
state_dict = torch.load(weight, map_location=device)
model = DeepLabV3().to(device)
model.load_state_dict(state_dict)

# Output dir: one timestamped folder per run under ./output/seg.
run_time = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())
output_dir = os.path.join("./output/seg", run_time)
# Create it (and any missing parents) only if it does not exist yet.
os.makedirs(output_dir, exist_ok=True)

# visualize
def visualize(im, gt, pred):
    """Build a one-row [image | ground truth | prediction] grid for the first sample of a batch.

    Args:
        im: batched image tensor. It is indexed as ``im[0]`` and permuted with
            (1, 2, 0), so presumably (B, C, H, W) — TODO confirm.
        gt: batched label tensor. ``gt[0]`` is passed to ``colorize_mask``,
            presumably as an (H, W) class-id map — TODO confirm.
        pred: batched class scores. ``pred[0]`` is arg-maxed over dim 0, so
            presumably (B, num_classes, H, W).

    Returns:
        The tensor produced by ``make_grid`` with the three panels in a single
        row (``nrow=3``) separated by 5 pixels of padding.
    """
    to_tensor = T.ToTensor()

    # Round-trip the CHW image through an HWC numpy array and ToTensor.
    # NOTE(review): per the torchvision docs, ToTensor rescales only uint8
    # input, so a float image in 0~255 would stay un-normalised while the
    # gt/pred panels come out in [0, 1] — confirm the image range.
    image_hwc = np.array(im[0].data.cpu().permute(1, 2, 0))
    image_panel = to_tensor(image_hwc)

    gt_ids = gt[0].data.cpu().numpy()
    gt_panel = to_tensor(colorize_mask(gt_ids).convert('RGB'))

    # Per-pixel predicted class id: index of the max score along dim 0.
    pred_ids = pred[0].data.max(0)[1].cpu().numpy()
    pred_panel = to_tensor(colorize_mask(pred_ids).convert('RGB'))

    # All panels are already on the CPU; no extra .cpu() is needed.
    panels = torch.stack([image_panel, gt_panel, pred_panel], 0)
    return make_grid(panels, nrow=3, padding=5)
def show(imgs):
    """Display one image tensor, or a list of them, in a single row of subplots.

    Args:
        imgs: an image tensor (e.g. the grid returned by ``visualize``) or a
            list of such tensors. Each one is detached and converted with
            ``F.to_pil_image`` before being drawn.
    """
    if not isinstance(imgs, list):
        imgs = [imgs]
    # Give the size to the figure that actually holds the axes. A separate
    # plt.figure(figsize=...) after subplots() would open a second, empty
    # figure and leave this one at the default size.
    _, axs = plt.subplots(ncols=len(imgs), squeeze=False, figsize=(32, 16))
    for i, img in enumerate(imgs):
        pil_img = F.to_pil_image(img.detach())
        axs[0, i].imshow(np.asarray(pil_img))
        axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])



In [25]:
from tkinter import image_names


# Export model predictions and ground-truth labels at a reduced resolution as
# PNG class-id maps, one file per sample named after the dataset index.
RESULT_DIR = r"F:\COMP90055\GMIDA\results\source_only\cityscapes128"
GT_DIR = r"F:\COMP90055\GMIDA\ground_truth\cityscapes128"
EVAL_SIZE = (128, 256)  # (H, W) passed to T.Resize

# Make sure the target folders exist; Image.save does not create directories.
os.makedirs(RESULT_DIR, exist_ok=True)
os.makedirs(GT_DIR, exist_ok=True)

# Built once, outside the loop. Images use the default (smooth) interpolation.
# Label maps hold class ids, so they must use nearest-neighbour resizing:
# bilinear would blend neighbouring ids into values that are not the true
# class of any pixel.
resize_image = T.Resize(EVAL_SIZE)
resize_label = T.Resize(EVAL_SIZE, interpolation=T.InterpolationMode.NEAREST)

model.eval()
with torch.no_grad():
    for index, image, label in tqdm(dataloader):
        image = resize_image(image)

        # NOTE(review): earlier debug notes recorded the loader's image values
        # as 0~255, yet the input is multiplied by 255 here — confirm the range
        # CityScapes returns and what DeepLabV3 expects.
        output = model((image * 255).to(device))['out']

        # Predicted class id per pixel for the first (only) sample.
        pred_ids = output[0].argmax(0).cpu().numpy().astype(np.uint8)
        gt_ids = resize_label(label)[0].cpu().numpy().astype(np.uint8)

        Image.fromarray(pred_ids).save(os.path.join(RESULT_DIR, index[0] + '.png'))
        Image.fromarray(gt_ids).save(os.path.join(GT_DIR, index[0] + '.png'))


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=500.0), HTML(value='')))


