In [1]:
# Mount Google Drive into the Colab runtime at /content/drive so that files stored
# on Drive can be read through the local filesystem by later cells.
from google.colab import drive
drive.mount("/content/drive", force_remount=True)

Mounted at /content/drive


Paper 1: [UNSUPERVISED IMAGE SEGMENTATION BY BACKPROPAGATION](https://kanezaki.github.io/pytorch-unsupervised-segmentation/ICASSP2018_kanezaki.pdf) (ICASSP 2018)

In [2]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
import cv2
import sys
import numpy as np
from skimage import segmentation
import torch.nn.init
from google.colab.patches import cv2_imshow

In [3]:
# Detect once whether a CUDA device is available; later cells consult this flag
# before calling .cuda(). The bare name on the last line displays the value.
use_cuda = torch.cuda.is_available()
use_cuda

True

In [87]:
# Hyper-parameters shared by the model, the training loop and the superpixel step.
CONFIG = dict(
    nChannel=100,            # width of every conv layer / number of candidate labels per pixel
    maxIter=1000,            # upper bound on training iterations
    minLabels=50,            # training stops once the number of distinct labels is <= this
    lr=0.1,                  # SGD learning rate
    nConv=2,                 # number of 3x3 conv layers (conv1 plus nConv-1 more)
    num_superpixels=10000,   # n_segments passed to skimage's SLIC
    compactness=10,          # compactness passed to skimage's SLIC
)

# CNN Model

In [13]:
class MyNet(nn.Module):
    """Fully convolutional network that maps an image to a per-pixel response map.

    Layout: conv1 (3x3) -> ReLU -> BN, followed by ``n_conv - 1`` further
    3x3 conv -> ReLU -> BN stages, followed by a 1x1 conv -> BN head.
    Every conv uses stride 1 and "same" padding, so the spatial size of the
    input is preserved and the output has ``n_channel`` channels.

    Args:
        input_dim: number of channels of the input tensor.
        n_channel: number of channels of every layer and of the output.
            Defaults to ``CONFIG["nChannel"]`` when omitted.
        n_conv: total number of 3x3 conv stages (conv1 included).
            Defaults to ``CONFIG["nConv"]`` when omitted.
    """

    def __init__(self, input_dim, n_channel=None, n_conv=None):
        super(MyNet, self).__init__()

        # Fall back to the notebook-wide configuration only when the caller
        # did not give explicit sizes.
        if n_channel is None:
            n_channel = CONFIG["nChannel"]
        if n_conv is None:
            n_conv = CONFIG["nConv"]

        ## conv1 and conv2 are the M feature extractors
        self.conv1 = nn.Conv2d(input_dim, n_channel, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(n_channel)
        self.conv2 = nn.ModuleList()
        self.bn2 = nn.ModuleList()
        for _ in range(n_conv - 1):
            self.conv2.append(nn.Conv2d(n_channel, n_channel, kernel_size=3, stride=1, padding=1))
            self.bn2.append(nn.BatchNorm2d(n_channel))

        ## Now the p channels are reduced to q (here p=q so n_channel only)
        self.conv3 = nn.Conv2d(n_channel, n_channel, kernel_size=1, stride=1, padding=0)
        self.bn3 = nn.BatchNorm2d(n_channel)

    def forward(self, x):
        """Return the batch-normalised response map for input batch ``x``."""
        x = self.conv1(x)
        x = F.relu(x)
        x = self.bn1(x)
        # Walk the layers that were actually constructed instead of re-reading
        # CONFIG: the global may have been edited since this model was built.
        for conv, bn in zip(self.conv2, self.bn2):
            x = conv(x)
            x = F.relu(x)
            x = bn(x)
        x = self.conv3(x)
        x = self.bn3(x)
        return x

# Train

In [88]:
# train
def _refine_with_superpixels(im_target, l_inds):
    """Overwrite, in place, every superpixel of ``im_target`` with its majority label.

    ``im_target`` is the flat per-pixel label array; each entry of ``l_inds``
    holds the flat pixel indices of one superpixel. Ties go to the smallest
    label because ``np.unique`` returns sorted labels and ``np.argmax`` picks
    the first maximum.
    """
    # TODO: do this on the torch tensor instead of numpy for faster calculation
    for sp_idx in l_inds:
        if len(sp_idx) == 0:
            # An empty superpixel has no majority label; nothing to rewrite.
            continue
        sp_labels, sp_counts = np.unique(im_target[sp_idx], return_counts=True)
        im_target[sp_idx] = sp_labels[np.argmax(sp_counts)]
    return im_target


def _write_class_masks(im_target, classes, height, width, result_dir):
    """Write one ``<label>.png`` per class: 0 where the pixel has that label, 255 elsewhere."""
    # cv2.imwrite only returns False when the folder is missing, so create it up front.
    os.makedirs(result_dir, exist_ok=True)
    for j in classes:
        mask = np.where(im_target == j, 0, 255).reshape((height, width)).astype(np.uint8)
        cv2.imwrite(os.path.join(result_dir, "{}.png".format(j)), mask)


def train(image, l_inds, result_dir="/content/result", model_path="/content/model.pt"):
    """Fit a fresh MyNet to a single image using self-generated, superpixel-refined labels.

    Each iteration takes the per-pixel argmax of the network response as the
    label, forces every superpixel to its majority label, and uses the result
    as the cross-entropy target. Training stops early as soon as the number of
    distinct labels is <= CONFIG["minLabels"]; in that case (and only then)
    one binary mask per remaining label is written to ``result_dir``.
    The model weights are always saved to ``model_path`` at the end.

    Args:
        image: input tensor indexed as (1, C, H, W); only batch element 0 is used.
        l_inds: sequence of index arrays, one per superpixel, indexing the
            flattened H*W pixel grid.
        result_dir: folder that receives the per-class mask PNGs.
        model_path: file that receives the model ``state_dict``.
    """
    height, width = image.size(2), image.size(3)

    # Place the model wherever the input already lives so the two can never disagree.
    model = MyNet(image.size(1)).to(image.device)
    model.train()
    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=CONFIG["lr"], momentum=0.9)

    for batch_idx in range(CONFIG["maxIter"]):
        # forward pass to get the class labels
        optimizer.zero_grad()
        output = model(image)[0]
        # (channels, H, W) -> (H*W, channels): one row of responses per pixel
        output = output.permute(1, 2, 0).contiguous().view(-1, output.size(0))

        # argmax step: the strongest channel is the pixel's label, shape (H*W,)
        _, target = torch.max(output, 1)
        im_target = target.detach().cpu().numpy()

        # superpixel refinement
        _refine_with_superpixels(im_target, l_inds)

        classes = np.unique(im_target)
        nLabels = len(classes)

        if nLabels <= CONFIG["minLabels"]:
            print(classes)
            print(nLabels)
            print ("nLabels", nLabels, "reached minLabels", CONFIG["minLabels"], ".")
            print(classes)
            _write_class_masks(im_target, classes, height, width, result_dir)
            break

        target = torch.from_numpy(im_target).to(image.device)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
        print (batch_idx, '/', CONFIG["maxIter"], ':', nLabels, loss.item())

    torch.save(model.state_dict(), model_path)

# Load Image

In [90]:
# Start from empty output folders: wipe whatever a previous run left behind, then recreate them.
!rm -rf /content/suppixel /content/result
!mkdir /content/suppixel /content/result

In [91]:
DIR = "/content/drive/MyDrive/BTP/100Docbank"
l = sorted(os.listdir(DIR))
if not l:
    # np.random.randint(0, 0) would otherwise fail with an unrelated-looking error.
    raise FileNotFoundError("No files found in {}".format(DIR))
np.random.seed(100)  # fixed seed so the same file is picked on every run
for i in range(1):
    idx = np.random.randint(0,len(l))
    im_path = os.path.join(DIR, l[idx])
    im = cv2.imread(im_path)
    if im is None:
        # cv2.imread signals an unreadable / non-image file by returning None.
        raise ValueError("cv2.imread could not read {}".format(im_path))
    print(im.shape)
    im = cv2.resize(im,(1000,1000))
    print(im.shape)

    # cv2_imshow(im)
    # HWC uint8 -> float32 in [0, 1], channels first, with a leading batch axis
    data = torch.from_numpy( np.array([im.transpose( (2, 0, 1) ).astype('float32')/255.]) )
    if use_cuda:
        data = data.cuda()
    #data is of form 1*c*h*w

    labels = segmentation.slic(im, compactness=CONFIG["compactness"], n_segments=CONFIG["num_superpixels"])
    labels = labels.reshape(im.shape[0]*im.shape[1])

    # Group the flat pixel indices by superpixel label with one stable sort
    # instead of one full-image scan per superpixel. The stable sort keeps the
    # indices of each group in ascending order, and the groups follow the
    # sorted order of u_labels.
    u_labels, sp_sizes = np.unique(labels, return_counts=True)
    sorted_pixels = np.argsort(labels, kind="stable")
    l_inds = np.split(sorted_pixels, np.cumsum(sp_sizes)[:-1])

    # Debug aid: uncomment to dump one mask image per superpixel.
    # print(len(l_inds))
    # for j in range(len(l_inds)):
    #     supim = np.zeros(labels.shape)
    #     for k in l_inds[j]:
    #         # print(j)
    #         supim[k]=255
    #     supim = supim.reshape(im.shape[:2])
    #     # print(supim.shape)
    #     # print(np.unique(supim))
    #     cv2.imwrite("/content/suppixel/{}.png".format(j), supim)
    #     # break
    train(data, l_inds)

    #load model
    # model = MyNet(data.size(1))
    # model.load_state_dict(torch.load("/content/model.pt")) 
    # model.eval()
    # if use_cuda:
    #     model = model.cuda()
    # label_colours = np.random.randint(255,size=(100,3))

    # im = cv2.imread(os.path.join(DIR, l[idx]))
    # print(im.shape)
    # im = cv2.resize(im,(1000,1000))
    # print(im.shape)

    # # cv2_imshow(im)
    # data = torch.from_numpy( np.array([im.transpose( (2, 0, 1) ).astype('float32')/255.]) )
    # if use_cuda:
    #     data = data.cuda()
    # data = Variable(data)

    # output = model(data)[ 0 ]
    # output = output.permute( 1, 2, 0 ).contiguous().view( -1, CONFIG["nChannel"])
    # ignore, target = torch.max( output, 1 )
    # im_target = target.data.cpu().numpy()

    # classes = np.unique(im_target)
    # print(classes)
    # for j in classes:
    #     res = np.array([0 if c==j else 255 for c in im_target])
    #     res_rgb = res.reshape((1000,1000)).astype( np.uint8 )
    #     cv2.imwrite(os.path.join("/content/result","{}.png".format(j)), res_rgb)
    #     # break
    # # im_target_rgb = np.array([label_colours[ c % 100 ] for c in im_target])
    # # im_target_rgb = im_target_rgb.reshape( im.shape ).astype( np.uint8 )

    # # cv2_imshow(im_target_rgb)

(2339, 1654, 3)
(1000, 1000, 3)




[ 0  2  3  6  7  8 11 12 13 16 25 32 38 39 40 47 50 51 54 58 60 65 68 71
 76 79 86 90 91 92 93]
31
nLabels 31 reached minLabels 50 .
[ 0  2  3  6  7  8 11 12 13 16 25 32 38 39 40 47 50 51 54 58 60 65 68 71
 76 79 86 90 91 92 93]


In [51]:
# List the superpixel dump folder. NOTE(review): the only code that writes there is
# currently commented out in the loading cell, so this is expected to print nothing.
!ls /content/suppixel