In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import glob
import os
import numpy as np
from PIL import Image
import torchvision.models as models
import copy
from torchvision.utils import save_image
import PIL
import skimage.io
import copy
from scipy import stats
import multiprocessing as mp

In [2]:
# helpers for loading image/mask pairs and cleaning noisy mask labels
def fnToIndex(fn):
    """Return the integer sample index encoded at the start of a file name.

    File names look like ``0001_sat.jpg``: the index is the text before the
    first underscore of the base name.

    Args:
        fn: path to a data file.

    Returns:
        The sample index as an ``int`` (e.g. ``1`` for ``.../0001_sat.jpg``).
    """
    # os.path.basename honours the platform's path separator; split('/') does not.
    return int(os.path.basename(fn).split('_')[0])

def getImg(fn):
    """Load an image file and convert it with torchvision's ``ToTensor``.

    As a coarse progress indicator, prints ``fn`` whenever the sample index
    parsed from the file name satisfies ``index % 100 == 1``.

    Args:
        fn: path to the image; its base name must start with the numeric
            sample index followed by ``_`` (parsed by ``fnToIndex``).

    Returns:
        ``numpy.ndarray`` obtained from the ``ToTensor`` output via ``.numpy()``.
    """
    if fnToIndex(fn) % 100 == 1:
        print(fn)
    # Image.open is lazy and keeps the file descriptor open; the context
    # manager closes it once the pixel data has been converted.
    with Image.open(fn) as img:
        return transforms.ToTensor()(img).numpy()

def voteByNeighbor(mask, noiseLabels, coordinate):
    """Map one pixel's RGB-bit code to a class id, relabelling noise pixels.

    ``mask`` holds per-pixel codes ``4*R + 2*G + B`` (0-7), as built in
    ``getMask``. A pixel whose code is not in ``noiseLabels`` is mapped
    directly. A noise pixel instead takes the most frequent non-noise code
    in the window of radius 4 around it (clipped at the borders; ties go to
    the smallest code). If the whole window is noise, the pixel keeps its
    own code.

    Args:
        mask: 2-D integer array (or nested sequence) of codes in 0..7.
        noiseLabels: codes to treat as unreliable.
        coordinate: ``(row, col)`` of the pixel to label.

    Returns:
        The class id (``int``) for the chosen code, looked up in ``toClass``.
    """
    toClass = [6, 4, 3, 0, 7, 2, 1, 5]  # RGB-bit code -> class id
    voteRange = 4
    grid = np.asarray(mask)
    i, j = coordinate
    label = grid[i, j]
    if label not in noiseLabels:
        return toClass[label]
    height, width = grid.shape
    window = grid[max(0, i - voteRange):min(height, i + voteRange + 1),
                  max(0, j - voteRange):min(width, j + voteRange + 1)]
    votes = window[~np.isin(window, noiseLabels)]
    if votes.size == 0:
        # No trustworthy neighbour: keep the pixel's own code instead of failing.
        return toClass[label]
    # bincount(...).argmax() returns the smallest code among ties (same as
    # scipy.stats.mode) and does not depend on SciPy's keepdims behaviour,
    # which made stats.mode(...)[0][0] raise on SciPy >= 1.11.
    return toClass[int(np.bincount(votes).argmax())]
    
def getMask(fn, minPixels=20):
    """Read a colour-coded segmentation image and return a cleaned class map.

    Each of the first three channels is thresholded at 128 and the bits are
    packed into a code ``4*R + 2*G + B`` (0-7). Codes that occur in at least
    one but fewer than ``minPixels`` pixels are treated as noise. Every pixel
    is then passed to ``voteByNeighbor``, which relabels noise pixels and
    maps codes to class ids.

    Args:
        fn: path to the mask image (read with ``skimage.io.imread``; must have
            at least three channels).
        minPixels: noise threshold on a code's pixel count. Defaults to 20.

    Returns:
        Integer array of class ids with the same height and width as the image.
    """
    bits = (skimage.io.imread(fn) >= 128).astype(int)
    codes = 4 * bits[:, :, 0] + 2 * bits[:, :, 1] + bits[:, :, 2]
    # Pixel count per code; minlength=8 guarantees an entry for every code.
    classCnt = np.bincount(codes.ravel(), minlength=8)
    noiseLabels = [c for c in range(8) if 0 < classCnt[c] < minPixels]
    height, width = codes.shape
    labels = [voteByNeighbor(codes, noiseLabels, (i, j))
              for i in range(height) for j in range(width)]
    return np.reshape(labels, (height, width))

def getMaskTuple(args):
    """Load one sample from an ``(image_path, mask_path)`` pair.

    Takes a single tuple argument so it can be used directly with
    ``Pool.map``.

    Returns:
        ``(getImg(image_path), getMask(mask_path))``; the image is loaded first.
    """
    imgPath, maskPath = args
    image = getImg(imgPath)
    labels = getMask(maskPath)
    return image, labels
    
def loadDataSet(root):
    """Load every (image, mask) pair under ``root`` using a process pool.

    Images are the ``*.jpg`` files and masks the ``*.png`` files directly in
    ``root``. Both lists are sorted and paired by position, so the n-th image
    goes with the n-th mask.

    Args:
        root: directory containing the images and masks.

    Returns:
        List of ``(image, mask)`` tuples from ``getMaskTuple``, in sorted
        file-name order.

    Raises:
        ValueError: if the numbers of images and masks differ, because the
            positional pairing would then be wrong.
    """
    print('loading data at', root)
    fnImgList = sorted(glob.glob(os.path.join(root, '*.jpg')))
    fnMaskList = sorted(glob.glob(os.path.join(root, '*.png')))
    if len(fnImgList) != len(fnMaskList):
        raise ValueError(
            f'{root}: found {len(fnImgList)} images but {len(fnMaskList)} masks')
    fnList = list(zip(fnImgList, fnMaskList))
    # The context manager shuts the workers down after map() returns, so
    # repeated calls do not accumulate idle processes.
    with mp.Pool(mp.cpu_count()) as pool:
        results = pool.map(getMaskTuple, fnList)
    print('data loaded at', root)
    return results

In [3]:
# Load and preprocess both splits; per the original note, this is the slowest cell.
trainSet, testSet = [loadDataSet(root) for root in ('hw2_data/p2_data/train', 'hw2_data/p2_data/validation')]

loading data at hw2_data/p2_data/train
hw2_data/p2_data/train/0001_sat.jpg
hw2_data/p2_data/train/0301_sat.jpg
hw2_data/p2_data/train/0101_sat.jpg
hw2_data/p2_data/train/0401_sat.jpg
hw2_data/p2_data/train/0201_sat.jpg
hw2_data/p2_data/train/0501_sat.jpg
hw2_data/p2_data/train/0801_sat.jpg
hw2_data/p2_data/train/0601_sat.jpg
hw2_data/p2_data/train/0901_sat.jpg
hw2_data/p2_data/train/0701_sat.jpg
hw2_data/p2_data/train/1001_sat.jpg
hw2_data/p2_data/train/1101_sat.jpg
hw2_data/p2_data/train/1401_sat.jpg
hw2_data/p2_data/train/1201_sat.jpg
hw2_data/p2_data/train/1501_sat.jpg
hw2_data/p2_data/train/1301_sat.jpg
hw2_data/p2_data/train/1601_sat.jpg
hw2_data/p2_data/train/1901_sat.jpg
hw2_data/p2_data/train/1701_sat.jpg
hw2_data/p2_data/train/1801_sat.jpg
data loaded at hw2_data/p2_data/train
loading data at hw2_data/p2_data/validation
hw2_data/p2_data/validation/0001_sat.jpg
hw2_data/p2_data/validation/0101_sat.jpg
hw2_data/p2_data/validation/0201_sat.jpg
data loaded at hw2_data/p2_data/validation

In [4]:
def randomCheckData(train_set, checkCnt):
    """Print the colours present in the masks of a few random samples.

    Sanity check for mask decoding. For up to ``checkCnt`` distinct randomly
    chosen samples, prints the sample index and the colour name of each class
    id occurring in its mask.

    Args:
        train_set: sequence of ``(image, mask)`` pairs whose masks hold class
            ids 0-7.
        checkCnt: number of samples to inspect (capped at ``len(train_set)``).
    """
    # Class id -> colour of the RGB-bit code it is mapped from (toClass in
    # voteByNeighbor); id 7 comes from code 4, i.e. the red bit alone.
    toColor = ['sky blue', 'yellow', 'pink', 'green', 'deep blue', 'white', 'black', 'red']
    order = np.random.permutation(len(train_set))
    for index in order[:checkCnt]:
        _, mask = train_set[index]
        existingLabels = np.unique(np.asarray(mask))
        colorList = [toColor[int(label)] for label in existingLabels]
        print(int(index), colorList)
        
def saveDataSet(root, dataSet):
    """Save each (image, mask) pair as ``img_<i>.npy`` and ``mask_<i>.npy``.

    Args:
        root: output directory; created (with parents) if it does not exist.
        dataSet: iterable of ``(image, mask)`` array pairs; ``i`` is each
            pair's position in the iterable.
    """
    os.makedirs(root, exist_ok=True)
    for i, (img, mask) in enumerate(dataSet):
        np.save(os.path.join(root, f'img_{i}.npy'), img)
        np.save(os.path.join(root, f'mask_{i}.npy'), mask)
    print('data saving completed at', root)
    
randomCheckData(train_set=trainSet, checkCnt=5)  # spot-check five random training samples

1763 ['green', 'deep blue', 'black']
808 ['yellow', 'green']
1864 ['sky blue', 'yellow', 'pink', 'green']
1781 ['sky blue', 'yellow', 'pink']
238 ['yellow', 'pink', 'deep blue', 'white']


In [5]:
# Write both splits to .npy files (one file per image and one per mask).
saveDataSet(root='hw2_data/p2_data/train_npy', dataSet=trainSet)
saveDataSet(root='hw2_data/p2_data/validation_npy', dataSet=testSet)

data saving completed at hw2_data/p2_data/train_npy
data saving completed at hw2_data/p2_data/validation_npy
