In [1]:
import os
import cv2
from PIL import Image
import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as transforms
from processer import equalize, zmIceColor, extractGreen
# Training-set root, built with os.path.join so the separator matches the
# host OS (a literal "..\\dataset\\..." string only works on Windows).
datasets_path = os.path.join("..", "dataset", "classifer", "train")

In [2]:
class PlantSeedDataset(Dataset):
    """Image dataset laid out as ``data_dir/<class name>/*.png``.

    Each item is ``(img, label, img_name)``:
    ``img`` is the image converted to RGB, passed through ``transform`` if one is given.
    ``label`` is the name of the directory containing the file (a ``str``).
    ``img_name`` is the file name.
    """

    def __init__(self, data_dir, transform=None):
        """
        PlantSeedDataset
        :param data_dir: str, path of the dataset root directory
        :param transform: torch.transform, preprocessing applied to each image
        """
        self.data_info = self.get_img_info(data_dir)
        self.transform = transform
        # Class-name -> integer index. NOTE: not used inside this class;
        # __getitem__ returns the directory-name string as the label.
        self.name_dic = {'Black-grass': 0, 'Charlock': 1, 'Cleavers': 2, 'Common Chickweed': 3,
                'Common wheat': 4, 'Fat Hen': 5, 'Loose Silky-bent': 6, 'Maize': 7,
                'Scentless Mayweed': 8, 'Shepherds Purse': 9, 'Small-flowered Cranesbill': 10, 'Sugar beet': 11}

    def __getitem__(self, index):
        path_img, label, img_name = self.data_info[index]
        # Release the file handle deterministically. convert() returns a new,
        # fully loaded image, so it stays valid after the file is closed.
        with Image.open(path_img) as raw_img:
            img = raw_img.convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        return img, label, img_name

    def __len__(self):
        return len(self.data_info)

    @staticmethod
    def get_img_info(data_dir):
        """Collect ``(path, label, file name)`` for every PNG under ``data_dir``.

        Every directory reached while walking ``data_dir`` is scanned for files
        ending in ``.png`` (case-insensitive); the label is that directory's name.
        Directory and file names are sorted so the sample order does not depend
        on the filesystem's listing order. Features cached by index elsewhere
        rely on that order.

        :param data_dir: str, dataset root directory
        :return: list of ``(path_img, label, img_name)`` tuples
        """
        data_info = []
        for root, dirs, _ in os.walk(data_dir):
            # Sorting in place also fixes the order in which os.walk descends.
            dirs.sort()
            for sub_dir in dirs:
                class_dir = os.path.join(root, sub_dir)
                img_names = sorted(name for name in os.listdir(class_dir)
                                   if name.lower().endswith('.png'))
                for img_name in img_names:
                    data_info.append((os.path.join(class_dir, img_name), sub_dir, img_name))

        return data_info

In [3]:
def _rgb_to_bgr(array):
    """Reorder the channels of an RGB numpy image into OpenCV's BGR order."""
    return cv2.cvtColor(array, cv2.COLOR_RGB2BGR)


# PIL image -> 256x256 PIL image -> numpy array -> BGR array for OpenCV code.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    np.asarray,
    transforms.Lambda(_rgb_to_bgr),
])
dataset = PlantSeedDataset(datasets_path, transform=transform)
dataloader = DataLoader(dataset)

In [4]:
from skimage.feature import hog
def processer(img):
    """Colour-enhance an image with ACE, then extract its green component.

    :param img: uint8 image with values in 0-255 (the dataset transform yields
        a BGR uint8 array).
    :return: the result of ``extractGreen`` on the enhanced uint8 image.
    """
    # Colour enhancement (ACE) on values scaled to [0, 1].
    ace_img = zmIceColor(img / 255.0, ratio=4, radius=3)
    # Back to 0-255. Clip before the cast: converting out-of-range floats to
    # uint8 is not well-defined in NumPy and can wrap around.
    ace_img = np.uint8(np.clip(ace_img * 255, 0, 255))
    # Extract green.
    gre_img = extractGreen(ace_img, 7)
    return gre_img

def SIFT(gre_img):
    """Detect SIFT keypoints on the whole image and compute their descriptors.

    :param gre_img: image passed straight to ``detectAndCompute`` (no mask).
    :return: ``(keypoints, descriptors)`` as returned by OpenCV; descriptors may
        be ``None`` when nothing is detected.
    """
    return cv2.SIFT_create().detectAndCompute(gre_img, None)

# Build the bag-of-visual-words (BOW) extractor.
def bow_init(feature_sift_list, cluster_count=100):
    """Cluster SIFT descriptors into a visual vocabulary and return a BOW extractor.

    All non-``None`` descriptor matrices are pooled and clustered with k-means.
    The cluster centres become the vocabulary used to encode images.

    :param feature_sift_list: iterable of per-image SIFT descriptor arrays;
        ``None`` entries (images without keypoints) are skipped.
    :param cluster_count: vocabulary size / number of k-means clusters
        (default 100, the value previously hard-coded).
    :return: ``cv2.BOWImgDescriptorExtractor`` with the vocabulary set.
    :raises ValueError: if no entry of ``feature_sift_list`` holds descriptors.
    """
    bow_kmeans_trainer = cv2.BOWKMeansTrainer(cluster_count)

    added = 0
    for feature_sift in feature_sift_list:
        if feature_sift is None:
            continue
        bow_kmeans_trainer.add(feature_sift)
        added += 1

    if added == 0:
        raise ValueError("bow_init: no SIFT descriptors to cluster (all entries are None)")

    # k-means clustering; the returned cluster centres are the vocabulary.
    voc = bow_kmeans_trainer.cluster()

    # FLANN matcher. `algorithm` selects the index type (Linear, KDTree, KMeans,
    # Composite, Autotune, ...). 1 is the randomized KD-tree index, whose
    # tree-count parameter is named "trees"; the earlier "tree" key is not that
    # parameter, so FLANN silently used its default instead of 5.
    FLANN_INDEX_KDTREE = 1
    flann_params = dict(algorithm=FLANN_INDEX_KDTREE, trees=5)
    flann = cv2.FlannBasedMatcher(flann_params, {})

    # The extractor computes each image's BOW histogram against the vocabulary.
    sift = cv2.SIFT_create()
    bow_img_descriptor_extractor = cv2.BOWImgDescriptorExtractor(sift, flann)
    bow_img_descriptor_extractor.setVocabulary(voc)

    return bow_img_descriptor_extractor

# Extract BOW features.
def bow_feature(bow_img_descriptor_extractor, image_list):
    """Encode each image as a bag-of-visual-words histogram.

    :param bow_img_descriptor_extractor: ``cv2.BOWImgDescriptorExtractor`` with a
        vocabulary already set (see ``bow_init``).
    :param image_list: sequence of BGR images; each is converted to grayscale
        before SIFT keypoint detection.
    :return: list with one ``(1, vocabulary_size)`` histogram per image. Images
        where SIFT finds no keypoints get an all-zero histogram instead of
        ``None``, so every entry has the same shape and can be stacked.
    """
    feature_bow_list = []
    sift = cv2.SIFT_create()
    vocabulary_size = bow_img_descriptor_extractor.descriptorSize()
    for image in image_list:
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        feature_bow = bow_img_descriptor_extractor.compute(gray, sift.detect(gray))
        if feature_bow is None or feature_bow.size == 0:
            # No keypoints -> no visual words: the histogram is all zeros.
            feature_bow = np.zeros((1, vocabulary_size), dtype=np.float32)
        feature_bow_list.append(feature_bow)
    return feature_bow_list


def HOG(gre_img, visualize=True):
    """Compute the HOG descriptor of a BGR image (converted to grayscale first).

    Uses 9 orientations, 6x6-pixel cells and 3x3-cell blocks, flattened into a
    single feature vector.

    :param gre_img: BGR image.
    :param visualize: when True (default, the original behaviour) also render the
        HOG visualization image. Pass False to skip that extra work when only the
        descriptor is needed; the second return value is then ``None``.
    :return: ``(hog_feature, hog_img)``.
    """
    gray_img = cv2.cvtColor(gre_img, cv2.COLOR_BGR2GRAY)
    hog_kwargs = dict(orientations=9,
                      pixels_per_cell=(6, 6),
                      cells_per_block=(3, 3),
                      feature_vector=True)
    if visualize:
        hog_feature, hog_img = hog(gray_img, visualize=True, **hog_kwargs)
        return hog_feature, hog_img
    # With visualize=False, skimage returns only the descriptor.
    return hog(gray_img, visualize=False, **hog_kwargs), None

# Local binary pattern (LBP) features.
def LBP(gre_img):
    """Apply skimage's ``local_binary_pattern`` ('var' method) to each channel.

    The image must have exactly three channels. Each channel from ``cv2.split``
    is processed with 8 sampling points on a circle of radius 4. The results are
    merged back in the original channel order.
    """
    from skimage.feature import local_binary_pattern
    n_points, radius, method = 8, 4, 'var'
    first, second, third = (local_binary_pattern(channel, n_points, radius, method)
                            for channel in cv2.split(gre_img))
    return cv2.merge((first, second, third))

In [81]:
# Peek at the first sample to check what the transform pipeline yields.
first_sample = next(iter(dataset), None)
if first_sample is not None:
    img, label, img_name = first_sample
    print(type(img), type(label), type(img_name))

<class 'numpy.ndarray'> <class 'str'> <class 'str'>


In [56]:
import pickle as pkl
gre_imgs, sift_features = [], []
for img, label, img_name in dataset:
    # Enhance + green extraction, then SIFT descriptors (keypoints are discarded).
    gre_img = processer(img)
    descriptors = SIFT(gre_img)[1]
    gre_imgs.append(gre_img)
    sift_features.append(descriptors)

# Cache both lists so later cells can reload them instead of recomputing.
for cache_path, payload in (("sift_features.pkl", sift_features), ("gre_img.pkl", gre_imgs)):
    with open(cache_path, "wb") as f:
        pkl.dump(payload, f)
len(sift_features)

4440

In [62]:
# Reload the cached preprocessed images and recompute their SIFT descriptors.
# NOTE: pickle.load can execute code — only load files this notebook produced.
with open('gre_img.pkl', 'rb') as f:
    gre_imgs = pkl.load(f)
sift_features = [SIFT(gre_img)[1] for gre_img in gre_imgs]
len(sift_features), len(gre_imgs)

4440

In [5]:
import pickle as pkl
def _unpickle(path):
    """Load one cached object written by an earlier cell (trusted local file only)."""
    with open(path, 'rb') as f:
        return pkl.load(f)


gre_imgs = _unpickle('gre_img.pkl')
sift_features = _unpickle('sift_features.pkl')

In [6]:
# Vocabulary from the cached SIFT descriptors, then one BOW histogram per image.
bow_extractor = bow_init(feature_sift_list=sift_features)
all_feature_bow = bow_feature(bow_img_descriptor_extractor=bow_extractor, image_list=gre_imgs)

In [7]:
# Sanity check: number of BOW histograms produced (one per image).
n_bow_features = len(all_feature_bow)
n_bow_features

4440

In [8]:
# Serialize the BOW histograms so they can be reloaded without re-clustering.
with open("all_feature_bow.pkl", mode="wb") as f:
    pkl.dump(all_feature_bow, f)

In [10]:
# Shape of a single BOW histogram.
first_bow = all_feature_bow[0]
first_bow.shape

(1, 100)

In [11]:
# NOTE(review): the recorded run of this cell ended in KeyboardInterrupt.
# HOG(...) here also renders a visualization image that is discarded, and the
# 6x6-cell descriptor is long — expect this cell to be slow and memory-heavy.
hog_features = [HOG(img)[0] for img in gre_imgs]
# Serialize the HOG descriptors.
with open("hog_features.pkl", "wb") as f:
    pkl.dump(hog_features, f)

KeyboardInterrupt: 

In [None]:
# Per-channel LBP map for every preprocessed image, then serialize them.
feature_lbps = [LBP(img) for img in gre_imgs]
with open("feature_lbps.pkl", "wb") as f:
    pkl.dump(feature_lbps, f)