In [None]:
import torch
import torch.nn as nn
import torch.nn.init as init
from torchvision import models,transforms,datasets
import torchvision
import torch.nn.functional as F
import numpy as np
import pandas as pd
from torch import optim
import sklearn
from PIL import Image
from torch.autograd import Variable
import os
import torchvision.transforms.functional as TF

In [None]:
def colormap(n):
    """Build an ``(n, 3)`` uint8 table of RGB colours, one row per label index.

    The bits of the label index are dealt out in turn to red, green and blue
    (bit 0 -> red, bit 1 -> green, bit 2 -> blue, bit 3 -> red, ...). The first
    triple of bits sets the most significant bit of each channel, the next
    triple the next-lower bit, and so on for eight triples.
    """
    palette = np.zeros([n, 3]).astype(np.uint8)

    for index in np.arange(n):
        # Accumulate red, green and blue side by side instead of in three variables.
        channels = np.zeros(3)

        for level in np.arange(8):
            weight = 1 << (7 - level)
            for channel in range(3):
                bit = 3 * level + channel
                channels[channel] = channels[channel] + weight * ((index & (1 << bit)) >> bit)

        palette[index, :] = channels

    return palette


def Relabel(olabel,nlabel,tensor):
    """Replace every occurrence of ``olabel`` in ``tensor`` with ``nlabel``, in place.

    ``tensor`` must be a ``torch.LongTensor``; the same (modified) tensor is returned.
    """
    assert isinstance(tensor, torch.LongTensor), 'tensor needs to be LongTensor'
    hits = tensor == olabel
    tensor.masked_fill_(hits, nlabel)
    return tensor


def ToLabel(image):
    """Turn ``image`` (anything ``np.array`` accepts) into an int64 tensor with a new leading axis."""
    pixels = np.array(image)
    labels = torch.from_numpy(pixels).long()
    return labels.unsqueeze(0)


class Colorize:
    """Turn a 2-D label tensor into a ``(3, H, W)`` uint8 colour tensor.

    The palette is the first ``n`` rows of ``colormap(256)``. Label 0 is never
    painted, so it stays black in the output.
    """

    def __init__(self, n=22):
        self.cmap = colormap(256)
        # Row n is written here, but the slice below keeps only rows 0..n-1,
        # so this write never reaches the stored palette. Kept as in the original.
        self.cmap[n] = self.cmap[-1]
        self.cmap = torch.from_numpy(self.cmap[:n])

    def __call__(self, gray_image):
        """Colour ``gray_image`` (a 2-D tensor of label ids) and return the CPU result."""
        # The output buffer lives on the CPU, so the boolean masks used to index
        # it must too; move the input over in case it is a CUDA tensor.
        gray_image = gray_image.cpu()
        size = gray_image.size()
        color_image = torch.zeros(3, size[0], size[1], dtype=torch.uint8)

        for label in range(1, len(self.cmap)):
            mask = gray_image == label

            color_image[0][mask] = self.cmap[label][0]
            color_image[1][mask] = self.cmap[label][1]
            color_image[2][mask] = self.cmap[label][2]

        return color_image

In [None]:
class MyDataSet(torch.utils.data.Dataset):
  """Image/mask segmentation dataset read from two directories.

  Files are paired by position after sorting both directory listings, so the
  two folders must sort into the same order.

  Args:
    image_path: directory containing the input images.
    mask_path: directory containing the label masks.
  """

  def __init__(self,image_path,mask_path):
    self.image_paths = image_path
    self.target_paths = mask_path
    self.files = sorted(os.listdir(self.image_paths))
    self.lables = sorted(os.listdir(self.target_paths))

  def transform(self, image, mask):
    """Resize both PIL inputs to 224x224 and convert them to tensors.

    Returns:
      ``(image, mask)``: the image as a normalised float tensor, the mask as
      an int64 tensor with a leading axis of size 1.
    """
    size = (224, 224)
    image = transforms.Resize(size=size)(image)
    # Label maps must be resized with nearest-neighbour sampling: any
    # interpolating filter blends neighbouring class ids into values that
    # correspond to no class.
    mask = mask.resize(size, resample=Image.NEAREST)

    normaliz=transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                  std=[0.229, 0.224, 0.225])

    # Transform to tensor
    image = TF.to_tensor(image)

    image=normaliz(image)

    mask=ToLabel(mask)
    # NOTE(review): 255 is remapped to 21, but the model later in this notebook
    # is built with 5 classes; confirm the masks never contain 255, otherwise
    # the loss will receive an out-of-range target.
    mask=Relabel(255,21,mask)

    return image, mask


  def __getitem__(self, index):
    """Load, transform and return the ``index``-th ``(image, mask)`` pair."""
    img_name = self.files[index]
    label_name = self.lables[index]
    # Force three channels so the 3-channel Normalize in transform() always
    # applies (same conversion the inference cell performs).
    image = Image.open(os.path.join(self.image_paths,img_name)).convert('RGB')
    mask = Image.open(os.path.join(self.target_paths,label_name))
    x, y = self.transform(image, mask)
    return x,y

  def __len__(self):
      return len(self.files)



In [None]:
#Add Data
# Each split folder holds an "imgs" and a "masks" directory.
_herlev_root = "/content/drive/My Drive/Herlev_Maps"


def _split_dataset(split):
    """Build the MyDataSet for one split folder ("Train", "Validation" or "Test")."""
    base = _herlev_root + "/" + split
    return MyDataSet(base + "/imgs", base + "/masks")


train_data = _split_dataset("Train")
val_data = _split_dataset("Validation")
test_data = _split_dataset("Test")

In [None]:
# Only the training loader shuffles. The evaluation loaders keep a fixed order
# so validation/test results are reproducible from run to run.
train_load=torch.utils.data.DataLoader(train_data,batch_size=20,shuffle=True)
val_load=torch.utils.data.DataLoader(val_data,batch_size=20,shuffle=False)
test_load=torch.utils.data.DataLoader(test_data,batch_size=20,shuffle=False)

In [None]:
class CrossEntropyLoss2d(nn.Module):
    """Per-pixel cross entropy: log-softmax over dim 1 followed by ``nn.NLLLoss``.

    Args:
        weight: optional per-class weight tensor handed to ``nn.NLLLoss``.
    """

    def __init__(self, weight=None):
        super().__init__()
        self.loss = nn.NLLLoss(weight)

    def forward(self, outputs, targets):
        """Return the loss of ``outputs`` against index 0 of dim 1 of ``targets``."""
        log_probs = F.log_softmax(outputs, dim=1)
        class_map = targets[:, 0, :, :]
        return self.loss(log_probs, class_map)

In [None]:
def save_checkpoint(state, is_best, filename='/content/drive/My Drive/DataSet/PSPnet:mark1.pth'):
    """Write ``state`` to ``filename`` with ``torch.save``, but only when ``is_best`` is truthy."""
    if not is_best:
        return
    print ("=> Saving a new best")
    torch.save(state, filename)

In [None]:
def load_checkpoint(checkpoint_fpath,model):
    """Load the ``'state_dict'`` entry of a saved checkpoint into ``model``.

    Args:
        checkpoint_fpath: path to a file written by ``torch.save`` holding a
            dict with a ``'state_dict'`` key.
        model: module whose parameters are overwritten in place.

    Returns:
        The same ``model`` object.
    """
    # Deserialize onto the CPU so a checkpoint saved from a GPU run can also be
    # opened on a machine without CUDA; load_state_dict then copies the values
    # into the model's own parameters, wherever they live.
    checkpoint = torch.load(checkpoint_fpath, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    return model

In [None]:
#Accuracy Metrics

def pixel_accuracy(image1,image2):
    """Fraction of positions at which two 2-D label images agree.

    Args:
        image1, image2: 2-D array-likes with the same number of elements.

    Returns:
        ``matches / (pixels + 1e-8)`` as a Python float; the epsilon makes an
        empty image yield 0.0 instead of dividing by zero.
    """
    image1 = np.array(image1)
    image2 = np.array(image2)
    [row, col] = image1.shape
    total_count = row * col
    # reshape also rejects an image2 whose element count differs from image1.
    flat1 = np.reshape(image1, total_count)
    flat2 = np.reshape(image2, total_count)
    # One vectorised comparison instead of a Python loop over every pixel.
    count = int(np.count_nonzero(flat1 == flat2))
    return count / (total_count + 1e-8)

def mean_accuracy(image1,image2,num_classes):
    """Per-class accuracy averaged over the classes that occur in ``image2``.

    Each pixel is counted under the class it has in ``image2``; it is correct
    when ``image1`` has the same value at that position.

    Args:
        image1, image2: 2-D array-likes of integer labels, same element count.
        num_classes: number of classes; labels index arrays of this length.

    Returns:
        Sum of per-class ``correct / (correct + incorrect + 1e-8)`` divided by
        the number of classes with at least one pixel in ``image2``.
    """
    image1 = np.array(image1)
    image2 = np.array(image2)
    [row, col] = image1.shape
    compared = np.reshape(image1, row * col)
    reference = np.reshape(image2, row * col)
    agree = compared == reference

    correct_labels = np.zeros(num_classes)
    incorrect_labels = np.zeros(num_classes)
    # np.add.at accumulates repeated indices, giving the same counts as
    # incrementing one pixel at a time.
    np.add.at(correct_labels, reference[agree], 1)
    np.add.at(incorrect_labels, reference[~agree], 1)

    pixels_per_class = correct_labels + incorrect_labels
    per_class_accuracy = correct_labels / (pixels_per_class + 1e-8)
    return per_class_accuracy.sum() / np.count_nonzero(pixels_per_class > 0)


def mean_IU(image1,image2,num_classes, ignore_index = None):
    """Mean intersection-over-union between two 2-D label images.

    For each class the score is
    ``agreeing pixels / (agreeing + only-in-image1 + only-in-image2 + 1e-8)``.
    A class absent from both images therefore scores 0 and still counts
    in the average.

    Args:
        image1, image2: 2-D array-likes of integer labels, same element count.
        num_classes: number of classes; labels index arrays of this length.
        ignore_index: optional iterable of class ids whose score is forced to 0
            and which are removed from the averaging denominator. ``None``
            (the default) ignores nothing.

    Returns:
        Sum of the per-class scores divided by
        ``num_classes - len(ignore_index)``.
    """
    image1 = np.array(image1)
    image2 = np.array(image2)
    [row, col] = image1.shape
    first = np.reshape(image1, row * col)
    second = np.reshape(image2, row * col)
    # Treat None (and any other falsy value) as "ignore nothing"; the original
    # called len() on None and raised TypeError with the default argument.
    ignored = list(ignore_index) if ignore_index else []

    agree = first == second
    correct_predictions = np.zeros(num_classes)
    incorrect_predictions = np.zeros(num_classes)
    incorrect_labels = np.zeros(num_classes)
    # np.add.at accumulates repeated indices, matching per-pixel increments.
    np.add.at(correct_predictions, first[agree], 1)
    np.add.at(incorrect_predictions, first[~agree], 1)
    np.add.at(incorrect_labels, second[~agree], 1)

    for i in ignored:
        correct_predictions[i] = 0
        incorrect_predictions[i] = 0
        incorrect_labels[i] = 0

    per_class = correct_predictions / (
        correct_predictions + incorrect_predictions + incorrect_labels + 1e-8)
    return per_class.sum() / (num_classes - len(ignored))


In [None]:
def initialize_weights(*models):
    """Re-initialise every Conv2d, Linear and BatchNorm2d inside the given modules.

    Conv2d/Linear weights get Kaiming-normal values and zero biases;
    BatchNorm2d gets weight 1 and bias 0. Parameters that do not exist
    (``bias=False``, ``affine=False``) are skipped.
    """
    for model in models:
        for module in model.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                # Underscore form: the un-suffixed kaiming_normal is deprecated.
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                if module.weight is not None:
                    nn.init.ones_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

In [None]:
class PyramidPool(nn.Module):
    """One pyramid-pooling branch: adaptive average pool, 1x1 conv, batch norm,
    ReLU, then bilinear resize back to the input's spatial size.

    Args:
        in_features: channels of the incoming feature map.
        out_features: channels produced by the 1x1 convolution.
        pool_size: output size given to ``nn.AdaptiveAvgPool2d``.
    """

    def __init__(self, in_features, out_features, pool_size):
        super().__init__()

        # NOTE(review): momentum=.95 means PyTorch's running statistics follow
        # the newest batch almost entirely; confirm this is intended.
        stages = [
            nn.AdaptiveAvgPool2d(pool_size),
            nn.Conv2d(in_features, out_features, 1, bias=False),
            nn.BatchNorm2d(out_features, momentum=.95),
            nn.ReLU(inplace=True),
        ]
        self.features = nn.Sequential(*stages)

    def forward(self, x):
        """Return the pooled features resized to ``x``'s height and width."""
        spatial_size = x.size()[2:]
        pooled = self.features(x)
        return F.interpolate(pooled, spatial_size, mode='bilinear', align_corners=True)


In [None]:
# PSP_Net
class PSPNet(nn.Module):
    """PSPNet-style segmentation model on a torchvision ResNet-50 backbone.

    The backbone is run up to ``layer4`` with its max-pool, average-pool and
    fully connected stages skipped. Four ``PyramidPool`` branches (pool sizes
    1, 2, 3, 6) are concatenated with the backbone features (2048 + 4 * 512 =
    4096 channels) and fed to a 3x3 classifier convolution.

    Args:
        num_classes: number of output channels of the classifier.
        pretrained: passed to ``torchvision.models.resnet50`` to choose between
            pretrained and randomly initialised backbone weights.
    """

    def __init__(self, num_classes, pretrained = True):
        super(PSPNet,self).__init__()
        print("initializing model")
        # Honour the caller's choice; the flag used to be ignored and the
        # pretrained weights were always loaded.
        self.resnet = torchvision.models.resnet50(pretrained = pretrained)

        self.layer5a = PyramidPool(2048, 512, 1)
        self.layer5b = PyramidPool(2048, 512, 2)
        self.layer5c = PyramidPool(2048, 512, 3)
        self.layer5d = PyramidPool(2048, 512, 6)

        self.final = nn.Sequential(
            nn.Conv2d(4096,num_classes, 3, padding=1, bias=True),
        )

        # Only the newly added layers are re-initialised; the backbone keeps
        # whatever weights resnet50() produced.
        initialize_weights(self.layer5a,self.layer5b,self.layer5c,self.layer5d,self.final)

    def forward(self, x):
        """Return per-class scores resized to the spatial size of ``x``."""
        size=x.size()
        x = self.resnet.conv1(x)
        x = self.resnet.bn1(x)
        x = self.resnet.relu(x)
        # resnet.maxpool is intentionally skipped here.
        x = self.resnet.layer1(x)
        x = self.resnet.layer2(x)
        x = self.resnet.layer3(x)
        x = self.resnet.layer4(x)
        # resnet.avgpool / resnet.fc are not used.

        x = self.final(torch.cat([
            x,
            self.layer5a(x),
            self.layer5b(x),
            self.layer5c(x),
            self.layer5d(x),
        ], 1))

        # NOTE(review): F.interpolate defaults to mode='nearest'; an earlier
        # revision used F.upsample_bilinear. Confirm which is intended before
        # changing it, since it alters the model's outputs.
        return F.interpolate(x,size[2:])

In [None]:
# Use the GPU when one is present; otherwise fall back to the CPU so the
# notebook still runs on hosts without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
num_classes = 5
# Per-class loss weights; class 0 gets weight 0 so it contributes nothing to the loss.
weights = torch.ones(num_classes)
weights[0]=0
# Follow the notebook-wide `device` rather than a second hard-coded 'cuda'.
weights=weights.to(device)


model = PSPNet(num_classes=num_classes, pretrained=True)
# Assigning the result keeps the cell from echoing the whole model repr.
model = model.to(device)


In [None]:
# Adam with L2 weight decay; the comma between the two keyword arguments was
# missing, which made this cell a SyntaxError.
optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-4)
criterion=CrossEntropyLoss2d(weights)

In [None]:
epochs=20
# model=load_checkpoint('/content/drive/My Drive/DataSet/PSPnet:mark1.pth',model)
# Per-epoch history: mean training loss, validation pixel accuracy, mean validation loss.
trainloss_data=[]
accuracy_data=[]
valloss_data=[]
best_acc=0
best_val_IU=0


def _batch_scores(output, target):
  """Return ``(pixel accuracies, mean IUs)``, one entry per image in the batch.

  ``output`` is the model's per-class score tensor and ``target`` the label
  tensor whose dim-1 index 0 holds the class map. The last class id is
  ignored in the IU, as in the rest of this notebook.
  """
  predictions = output.detach().cpu().max(1)[1]
  truth = target[:,0,:,:].detach().cpu()
  accs = [pixel_accuracy(p, t) for p, t in zip(predictions, truth)]
  ious = [mean_IU(p, t, num_classes, ignore_index = [num_classes - 1]) for p, t in zip(predictions, truth)]
  return accs, ious


for epoch in range(epochs):
  train_loss=0.0
  model.train()
  counter=0
  print("Epoch : ",epoch+1)
  for inputs,target in train_load:
    counter+=1
    inputs,target=inputs.to(device),target.to(device)
    optimizer.zero_grad()
    output = model(inputs)
    loss = criterion(output,target)
    loss.backward()
    optimizer.step()
    print(counter)
    # .item() yields a plain float; summing the tensor itself would keep every
    # batch's autograd graph alive until the end of the epoch.
    train_loss+=loss.item()*inputs.size(0)
  # Training accuracy/IU are a cheap progress indicator taken from the LAST
  # batch of the epoch only, not from the whole training set.
  train_accs, train_IUs = _batch_scores(output, target)
  train_acc = sum(train_accs) / len(train_accs)
  train_IU = sum(train_IUs) / len(train_IUs)
  print("Train Accuracy : ",train_acc)
  print("Train Loss : ",train_loss/len(train_load.dataset))
  print("Train IU : ",train_IU)

  model.eval()
  with torch.no_grad():
    counter=0
    val_loss=0.0
    val_accs=[]
    val_IUs=[]
    for inputs,target in val_load:
      counter+=1
      print(counter)
      inputs,target=inputs.to(device),target.to(device)

      output = model(inputs)
      valloss = criterion(output, target)
      val_loss+=valloss.item()*inputs.size(0)
      # Validation IU selects the checkpoint, so score every batch rather
      # than only the final one.
      batch_accs, batch_IUs = _batch_scores(output, target)
      val_accs.extend(batch_accs)
      val_IUs.extend(batch_IUs)

    # max(..., 1) keeps an empty validation loader from dividing by zero.
    val_acc = sum(val_accs) / max(len(val_accs), 1)
    val_IU = sum(val_IUs) / max(len(val_IUs), 1)

    is_best = bool(best_val_IU<val_IU)

    print(is_best)
    if is_best:
      best_val_IU=val_IU
      save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
      }, is_best)

    print("Validation Accuracy : ",val_acc)
    print("Validation Loss : ",val_loss/len(val_load.dataset))
    print("Validation IU : ",val_IU)

  trainloss_data.append(train_loss/len(train_load.dataset))
  valloss_data.append(val_loss/len(val_load.dataset))
  accuracy_data.append(val_acc)

  print("\n")

In [None]:
# Evaluate the best saved checkpoint on the whole test set.
model=load_checkpoint('/content/drive/My Drive/DataSet/PSPnet:mark1.pth',model)
model.eval()
test_loss=0.0
test_accs=[]
test_IUs=[]
with torch.no_grad():
  counter=0
  for inputs,target in test_load:
    counter+=1
    print(counter)
    inputs,target=inputs.to(device),target.to(device)

    output = model(inputs)
    testloss =  criterion(output, target)
    # Accumulate a plain float so the final print shows a number, not a tensor.
    test_loss+=testloss.item()*inputs.size(0)

    # Score every image of every batch; the metrics used to cover only the
    # last batch the loader happened to yield.
    predictions = output.cpu().max(1)[1]
    truth = target[:,0,:,:].cpu()
    for i in range(len(predictions)):
      test_accs.append(pixel_accuracy(predictions[i], truth[i]))
      test_IUs.append(mean_IU(predictions[i], truth[i], num_classes, ignore_index = [num_classes - 1]))

# max(..., 1) keeps an empty test loader from dividing by zero.
test_acc = sum(test_accs) / max(len(test_accs), 1)
test_IU = sum(test_IUs) / max(len(test_IUs), 1)
print("Test Accuracy : ",test_acc)
print("Test Loss : ",test_loss/len(test_load.dataset))
print("Test IU : ",test_IU)


In [None]:
# Run the model on one training image and show the coloured prediction next to the input.
# NOTE(review): `dir` shadows the builtin and is not used in this cell; kept
# only in case other cells read it.
dir="/content/drive/My Drive/Herlev_Maps/Train/imgs"
ref_image=Image.open("/content/drive/My Drive/Herlev_Maps/Train/masks/148494967-148494986-001-d.bmp")

input_transform = transforms.Compose([
	transforms.Resize(size=(224, 224)),
	transforms.ToTensor(),
	transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                      std=[0.229, 0.224, 0.225])])


im1 = input_transform(Image.open("/content/drive/My Drive/Herlev_Maps/Train/imgs/148494967-148494986-001.BMP").convert('RGB')) #3,224,224
input_transform1 = transforms.Compose([
    transforms.Resize(size=(224,224)),
    transforms.ToTensor()
])
color_transform = Colorize()
image_transform = transforms.ToPILImage()
im2=input_transform1(Image.open("/content/drive/My Drive/Herlev_Maps/Train/imgs/148494967-148494986-001.BMP"))
im1=im1.to(device)
model.eval()
# torch.no_grad() replaces Variable(..., volatile=True): the volatile flag no
# longer disables gradient tracking.
with torch.no_grad():
  label = model(im1.unsqueeze(0))#1,N,224,224
# Move the class map to the CPU before colouring; Colorize fills a CPU buffer.
label = color_transform(label[0].max(0)[1].cpu())#3,224,224
output = image_transform(label)
# presumably ref_image is a palette-mode ('P') bitmap, which quantize(palette=...) requires — verify
output = output.quantize(palette=ref_image)
# output.save(label_name)
output.show()
im2=image_transform(im2)
im2.show()
