In [55]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
from numpy import *
import os
import numpy as np
import matplotlib.pyplot as plt

import torchvision
from torchvision import transforms
from torchvision.transforms import functional as tf
#from torchvision.transforms import F as tf

import glob
from PIL import Image

import sys
import time
%matplotlib inline
%config InlineBackend.figure_format = "retina"


import random
def setup_seed(seed):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch (CPU and all CUDA
    devices), and configures cuDNN for deterministic kernel selection.

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # silently ignored when CUDA is unavailable
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner may pick different algorithms from run to run; turn it
    # off explicitly so the deterministic setting above is not undermined.
    torch.backends.cudnn.benchmark = False


# Fix all random seeds once, before any data loading or model construction.
setup_seed(88)



In [56]:
#————————————————————————————————————————————————————————————————————————————————————————
#  Define the transform (PIL image -> tensor) shared by images and labels

# Conversion applied to both the input images and the label masks in Dataset.
# ToTensor turns a PIL image into a (C, H, W) tensor; 8-bit images are scaled
# to [0, 1], which is why Dataset multiplies the labels back by 255.
transform = transforms.Compose(
    [
        # transforms.Resize((512, 512)),  # optional: resize to 512 x 512
        transforms.ToTensor(),
    ]
)

class Dataset(data.Dataset):
    """Paired (image, label mask) dataset for segmentation.

    Args:
        imgs_path: indexable sequence of image file paths.
        annos_path: indexable sequence of label-mask file paths; entry ``i``
            must belong to ``imgs_path[i]``.
        cropp: if True, cut the same random square window out of both the
            image and its mask.
        crop_size: side length of that random crop (default 128, the value that
            used to be hard-coded). Ignored when ``cropp`` is False.

    Each item is ``(img_tensor, anno_tensor)``: ``img_tensor`` is the float
    tensor produced by ``transform``; ``anno_tensor`` is an int64 (H, W) tensor
    holding the mask's 8-bit pixel values, which this notebook treats as class
    indices.
    """

    def __init__(self, imgs_path, annos_path, cropp=False, crop_size=128):
        self.imgs_path = imgs_path
        self.annos_path = annos_path
        self.cropp = cropp
        self.crop_size = crop_size

    def crop(self, img, label, size):
        """Apply one identical random ``size`` x ``size`` crop to ``img`` and ``label``.

        The crop window is drawn once and reused so the mask stays aligned
        with the image.
        """
        params = transforms.RandomCrop.get_params(img=img, output_size=(size, size))
        img = tf.crop(img, *params)
        label = tf.crop(label, *params)
        return img, label

    def __getitem__(self, index):
        img_path = self.imgs_path[index]
        anno_path = self.annos_path[index]

        # Image is kept in its file mode; use .convert('RGB') for a 3-channel model.
        pil_img = Image.open(img_path)
        # 'L' = single-channel 8-bit, so every pixel value is in 0..255.
        anno_img = Image.open(anno_path).convert('L')

        if self.cropp:
            pil_img, anno_img = self.crop(pil_img, anno_img, self.crop_size)

        img_tensor = transform(pil_img).to(torch.float)

        # transform scales the 8-bit mask to [0, 1]; multiply by 255 to get the
        # integer labels back. Round before casting: in float32, (k/255)*255 can
        # come out just below k, and a plain .long() would truncate it to k-1.
        anno_tensor = torch.round(transform(anno_img) * 255.)

        # (1, H, W) -> (H, W). Only the channel axis is dropped, so an image with
        # H == 1 or W == 1 keeps its spatial shape (a bare squeeze() would not).
        anno_tensor = anno_tensor.squeeze(0).type(torch.long)

        return img_tensor, anno_tensor

    def __len__(self):
        return len(self.imgs_path)

In [57]:
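#——————————————————————————————————————————————————————————————————————
# U-Net model
#
# Building blocks (Downsample, Upsample) followed by the full encoder/decoder
# network (Unet_model).
#——————————————————————————————————————————————————————————————————————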
class Downsample(nn.Module):
    """Encoder stage: optional 2x2 max-pool, then two (conv3x3 -> BN -> ReLU) layers.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by both 3x3 convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        # Attribute names and layer indices define the checkpoint keys
        # (conv_relu.N.*), so they must not change.
        self.conv_relu = nn.Sequential(*stages)
        self.pool = nn.MaxPool2d(kernel_size=2)

    def forward(self, x, is_pool=True):
        """Run the stage; ``is_pool=False`` skips the max-pool (keeps H and W)."""
        features = self.pool(x) if is_pool else x
        return self.conv_relu(features)

# up-sample layer
class Upsample(nn.Module):
    """Decoder stage: two (conv3x3 -> BN -> ReLU) layers, then a 2x transposed conv.

    Args:
        channels: the stage expects ``2 * channels`` input channels and emits
            ``channels // 2`` channels at twice the input height and width.
    """

    def __init__(self, channels):
        super().__init__()
        # Layer order/indices define the checkpoint keys (conv_relu.N.*, upconv.0.*).
        self.conv_relu = nn.Sequential(
            nn.Conv2d(2 * channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
        )
        # kernel 3, stride 2, padding 1, output_padding 1 -> exactly 2x H and W.
        self.upconv = nn.Sequential(
            nn.ConvTranspose2d(
                channels, channels // 2,
                kernel_size=3, stride=2, padding=1, output_padding=1,
            ),
        )

    def forward(self, x):
        refined = self.conv_relu(x)
        return self.upconv(refined)

# 
class Unet_model(nn.Module):
    """U-Net producing one logit map per class for every pixel.

    Args:
        in_channels: channels of the input image (1 = grayscale, 3 = RGB).
            Default 1, matching the previously hard-coded value.
        num_classes: number of output logit maps. Default 26 (25 classes + 1,
            presumably background).

    Input height and width must be divisible by 16: the encoder halves them
    four times and the decoder doubles them four times, and every skip
    concatenation needs matching spatial sizes.
    """

    def __init__(self, in_channels=1, num_classes=26):
        super(Unet_model, self).__init__()
        # Attribute names are the checkpoint keys; keep them unchanged so saved
        # weights still load.
        self.down1 = Downsample(in_channels, 64)
        self.down2 = Downsample(64, 128)
        self.down3 = Downsample(128, 256)
        self.down4 = Downsample(256, 512)
        self.down5 = Downsample(512, 1024)

        # Bottleneck upsampling: 1024 -> 512 channels, 2x H and W.
        self.up = nn.Sequential(
            nn.ConvTranspose2d(1024, 512, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
        )

        self.up1 = Upsample(512)
        self.up2 = Upsample(256)
        self.up3 = Upsample(128)

        # Final 3x3 conv pair on [skip x1 | decoder output] (64 + 64 channels).
        self.conv_2 = Downsample(128, 64)

        # 1x1 conv to per-class logits.
        self.last = nn.Sequential(nn.Conv2d(64, num_classes, kernel_size=1))

    def forward(self, input):
        """Map an (N, in_channels, H, W) batch to (N, num_classes, H, W) logits."""
        # Encoder; shapes below assume an (N, in_channels, H, W) input.
        x1 = self.down1(input, is_pool=False)  # (N, 64,   H,    W)
        x2 = self.down2(x1)                    # (N, 128,  H/2,  W/2)
        x3 = self.down3(x2)                    # (N, 256,  H/4,  W/4)
        x4 = self.down4(x3)                    # (N, 512,  H/8,  W/8)
        x5 = self.down5(x4)                    # (N, 1024, H/16, W/16)

        # Decoder with skip connections (skip features first in each concat).
        x = self.up(x5)                            # (N, 512, H/8, W/8)
        x = self.up1(torch.cat([x4, x], dim=1))    # (N, 256, H/4, W/4)
        x = self.up2(torch.cat([x3, x], dim=1))    # (N, 128, H/2, W/2)
        x = self.up3(torch.cat([x2, x], dim=1))    # (N, 64,  H,   W)
        x = self.conv_2(torch.cat([x1, x], dim=1), is_pool=False)  # (N, 64, H, W)

        return self.last(x)                        # (N, num_classes, H, W)

    
# The next cell loads the previously trained weights into this model.


In [58]:



# load weight to model
PATH = './.............pth' # locate to the weights path
model =  Unet_model()
model.load_state_dict(torch.load(PATH,map_location=torch.device('cpu') ))

# load data
# read the png files
MB_R = glob.glob('........./*.png');MB_R = sorted(MB_R);
print('MB_R: ',len(MB_R))

# read the label files
MB_anno_R = glob.glob('......../*.png');print(len(MB_anno_R));

test_images = MB_R
test_images = sorted(test_images) 
test_annos = MB_anno_R
test_annos = sorted(test_annos) 
test_dataset = Dataset(test_images,test_annos)


# 
test_dataloader = data.DataLoader(test_dataset,
                                   batch_size = 1,
                                   shuffle=False,
                                  num_workers=0) # 

model.eval()
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model.to(device)


# save to .npy file [26,256,256]
i = 0
with torch.no_grad():
    for x, y in test_dataloader:
        i = i+1
        x = x.to(device)
        y = y.to(device)
        y_pred = model(x)
        y_pred = y_pred.squeeze().cpu().numpy()


        if i < 10:
          filename = './ '.npy'
          np.save(filename,y_pred)
        elif i>=10 and i<=99:
          filename = './' '.npy'
          np.save(filename,y_pred)
        else:
          filename = './' '.npy'
          np.save(filename,y_pred)

        

        print(y_pred.shape)
        # 


MB_R:  521
521
(26, 256, 256)
(26, 256, 256)
(26, 256, 256)
(26, 256, 256)
(26, 256, 256)
(26, 256, 256)
(26, 256, 256)
(26, 256, 256)
(26, 256, 256)
(26, 256, 256)
(26, 256, 256)
(26, 256, 256)
(26, 256, 256)
(26, 256, 256)
(26, 256, 256)
(26, 256, 256)
(26, 256, 256)
(26, 256, 256)
(26, 256, 256)
(26, 256, 256)
(26, 256, 256)
(26, 256, 256)
(26, 256, 256)
(26, 256, 256)
(26, 256, 256)
(26, 256, 256)
(26, 256, 256)
(26, 256, 256)
(26, 256, 256)
(26, 256, 256)
(26, 256, 256)
(26, 256, 256)
(26, 256, 256)
(26, 256, 256)
(26, 256, 256)
(26, 256, 256)
(26, 256, 256)
(26, 256, 256)
(26, 256, 256)
(26, 256, 256)
(26, 256, 256)
(26, 256, 256)
(26, 256, 256)
(26, 256, 256)
(26, 256, 256)
(26, 256, 256)
(26, 256, 256)
(26, 256, 256)
(26, 256, 256)
(26, 256, 256)
(26, 256, 256)
(26, 256, 256)
(26, 256, 256)
(26, 256, 256)
(26, 256, 256)
(26, 256, 256)
(26, 256, 256)
(26, 256, 256)
(26, 256, 256)
(26, 256, 256)
(26, 256, 256)
(26, 256, 256)
(26, 256, 256)
(26, 256, 256)
(26, 256, 256)
(26, 256, 