<a href="https://colab.research.google.com/github/aksl20/safran-automl/blob/dataset/create_texture_dataset.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# New texture images dataset

This dataset relies on the anomaly detection dataset proposed by MVTec, and uses the textile textures from its carpet category.
Originally, the main goal of the dataset is to design anomaly detection algorithms trained only on defect-free samples. For experimental purposes, we propose to use this dataset as a classification task on patches sampled from the test data, with the following classes:
 * good
 * color
 * cut
 * hole
 * thread
 * metal_contamination


## Dataset creation methodology

As the test set of the original dataset contains few defect samples, we propose a simple and quick method to add artificial samples using simple transformations: rotations and flips.

The classification dataset is created from different configurations, as follows:
* For each class described above, we generate images by rotating each patch by an angle between 0 and 150 degrees.

This results in 40 different combinations, all of which are used and stored.

This is done for all patches: for good samples we keep only 100 samples per parameter configuration, and for defect images we keep the samples where at least 10% of the patch contains a defect zone (according to the mask of the defect images).


In [1]:
import os, glob
import numpy as np
import torch
from scipy.ndimage import map_coordinates
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import seaborn as sns
from skimage import io
from skimage.color import rgb2gray
from skimage.transform import resize
from google.colab import drive
# Mount Google Drive inside the Colab VM at /content/drive.
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


  import pandas.util.testing as tm


In [0]:
# Download and unpack one MVTec anomaly-detection category into <root>/www/.
# The `#@param` trailers are Colab form annotations; the `!` lines are IPython
# shell escapes, so this cell only runs inside a notebook.
root = "/content/SSL-research" #@param {type:"string"}
dataset_name = "carpet" #@param ["carpet"]
data_root = os.path.join(root,'www/')
# NOTE(review): dataset_path is not used anywhere in this cell; the check
# below rebuilds the same path by hand.
dataset_path = os.path.join(data_root, dataset_name)
# Skip the download when the dataset folder already exists.
# NOTE(review): the second condition lists data_root, not the dataset folder,
# so an existing but empty dataset folder still counts as "already there".
if os.path.exists(os.path.join(data_root,dataset_name)) and len(os.listdir(data_root))>0:
    pass
else:
    # NOTE(review): the FTP URL embeds guest credentials in plain text.
    url = f'ftp://guest:GU.205dldo@ftp.softronics.ch/mvtec_anomaly_detection/{dataset_name}.tar.xz'
    # -N: only re-download when the remote file is newer than the local copy.
    !wget -N $url
    tarfile = f'{dataset_name}.tar.xz'
    # Extract the xz-compressed archive into the current directory, then drop it.
    !tar -xJf  $tarfile
    !rm  -rf $tarfile
    if not os.path.exists(data_root): os.makedirs(data_root)
    # Presumably the notebook's working directory is /content, so the archive
    # unpacks to /content/<dataset_name> -- TODO confirm.
    path_data = f'/content/{dataset_name}'
    print(data_root)
    # Move the extracted folder under data_root.
    !mv $path_data "$data_root"
    


In [0]:
from torch.utils.data import Dataset, DataLoader
from skimage import io
from skimage.color import rgb2gray
from skimage.transform import resize
import h5py
from datetime import datetime
from tqdm import tqdm
import pandas as pd
from torchvision import transforms,utils

import contextlib


In [0]:

class PatchManager(object):
    """Sample fixed-size, optionally rotated patches from a 2-D image.

    Parameters
    ----------
    img_size : tuple of int
        Shape of the images patches are taken from; used to build the
        default grid of patch centres.
    patch_shape : tuple of int, optional
        ``(h, w)`` of a patch, both even. Defaults to ``(64, 64)``.
    stride : tuple of int, optional
        Spacing of the default centre grid. Defaults to ``(32, 32)``.
    centers_list : numpy.ndarray, optional
        Explicit ``(n, 2)`` array of centres. Anything that is not an
        ``ndarray`` (including ``None``) triggers the regular grid.
    """

    def __init__(self,img_size, patch_shape=None, stride=None, centers_list=None):
        self.img_size = img_size
        # Resolve the defaults first: validating or using the raw arguments
        # would fail with a TypeError whenever they are left to None.
        self.patch_shape = patch_shape if patch_shape is not None else (64,64)
        assert (self.patch_shape[0]%2==0) and (self.patch_shape[1]%2==0), "the patch shapes must be define with even number"
        self.h,self.w = self.patch_shape
        self.stride = stride if stride is not None else (32,32)
        if isinstance(centers_list, np.ndarray):
            self.center_list = centers_list
        else:
            self.center_list = PatchManager.get_center_lists_grid(img_shape=self.img_size, stride=self.stride)
        # Half extents of the patch around its centre.
        self.fh,self.fw = self.h // 2, self.w // 2

    @staticmethod
    def get_center_lists_grid(img_shape,stride):
        """Return an ``(n, 2)`` integer array of regularly spaced centres.

        Along each axis the centres run from ``stride`` to
        ``img_shape - stride`` and their number is
        ``(img_shape - stride) // stride``, each axis using its own stride.
        Column 0 holds the values derived from ``img_shape[0]``/``stride[0]``,
        column 1 those derived from ``img_shape[1]``/``stride[1]``.
        """
        x = np.linspace(stride[0], img_shape[0]-stride[0],(img_shape[0]-stride[0])// stride[0])
        # The count along the second axis must use stride[1], not stride[0].
        y = np.linspace(stride[1], img_shape[1]-stride[1],(img_shape[1]-stride[1])// stride[1])
        X,Y = np.meshgrid(x,y)
        centers = np.hstack((X.ravel().reshape(-1,1), Y.ravel().reshape(-1,1))).astype(int)
        return centers

    @staticmethod
    def get_rotation_matrix(angle):
        """Return the 2x2 rotation matrix for ``angle`` given in degrees."""
        theta = np.radians(angle)
        c, s = np.cos(theta), np.sin(theta)
        return np.array(((c, -s), (s, c)))

    def get_patch(self, img, center_point, angle=0):
        """Extract the ``(h, w)`` patch centred on ``center_point``.

        The sampling grid is rotated by ``angle`` degrees around the centre
        and the image is sampled with bilinear interpolation (``order=1``);
        coordinates falling outside the image are mirrored.

        NOTE: the first element of ``center_point`` is applied along axis 1
        of ``img`` and the second along axis 0, despite the local names.
        """
        center_row, center_col = center_point
        h,w = self.patch_shape
        ty = np.arange(center_col-self.fh, center_col+self.fh)
        tx = np.arange(center_row-self.fw, center_row+self.fw)
        rotatex, rotatey= np.meshgrid(tx,ty)
        # Offsets relative to the centre, flattened to a 2 x (h*w) matrix so a
        # single matrix product rotates every sampling point.
        coords = np.asarray([rotatex.reshape((1,h*w))-center_row,
                             rotatey.reshape((1,h*w))-center_col]).reshape(2,h*w)
        rotatedcoords = np.matmul(PatchManager.get_rotation_matrix(angle), coords)
        return map_coordinates(img, [rotatedcoords[1]+center_col, rotatedcoords[0]+center_row], order=1, mode='mirror').reshape(h,w)



In [0]:
class ImageDataset(Dataset):
    """Patch-level dataset built from a folder of grayscale-converted images.

    Images (and, for ``datatype='test'``, their defect masks) are read from
    ``<root_folder>/<problem>/...``, resized to ``img_size``, cached in an
    HDF5 file, and served as patches cut around a regular grid of centres.
    """

    def __init__(self, root_folder, problem='carpet', datatype='train', img_size=None, patch_shape=None, stride=None,
                 test_example='hole', processed_root_folder=None,
                 sample_size=None, transform=None,download=True,angle=0,vflip=False,hflip=False,split=None):
        """Load (or build) the image cache and prepare the patch index.

        Parameters
        ----------
        root_folder : str
            Folder containing ``<problem>/train``, ``<problem>/test`` and
            ``<problem>/ground_truth``.
        problem : str
            Dataset category sub-folder.
        datatype : {'train', 'test'}
            ``'train'`` reads ``train/good`` and yields patches only;
            ``'test'`` reads ``test/<test_example>`` and yields
            ``(patch, mask)`` pairs.
        img_size, patch_shape, stride : tuple of int, optional
            Default to ``(512, 512)``, ``(64, 64)`` and ``(40, 40)``.
        test_example : str
            Defect class to read when ``datatype='test'``.
        processed_root_folder : str, optional
            Where the HDF5 cache lives; defaults to
            ``<root_folder>/processed/<problem>``.
        sample_size : int, optional
            Maximum number of patches kept after random sampling.
        transform : callable, optional
            Applied to each patch; defaults to ``transforms.ToTensor()``.
        angle : float
            Rotation, in degrees, applied to every extracted patch.
        split : {'train', 'test', None}
            ``'train'`` keeps the first 70% of the images, any other
            non-None value keeps the remaining 30%, ``None`` keeps all.
        download, vflip, hflip
            Accepted for interface compatibility; currently unused.
        """
        self.angle=angle
        self.split=split
        self.datatype = datatype
        self.img_size = (512, 512) if img_size is None else img_size
        self.patch_shape = (64, 64) if patch_shape is None else patch_shape
        self.stride = (40, 40) if stride is None else stride
        self.pm = PatchManager(img_size=self.img_size, patch_shape=self.patch_shape, stride=self.stride)
        self.process_data(datatype, self.img_size, self.patch_shape, problem, processed_root_folder, root_folder, self.stride,
                          test_example)

        has_mask = test_example != 'good'
        self.data = self.data[:]
        self.mask = self.mask[:] if has_mask else None
        # Only split when asked to: previously split=None silently fell into
        # the "test" branch and discarded the first 70% of the images.
        if split is not None:
            isplit_train = int(0.7 * self.data.shape[0])
            part = slice(None, isplit_train) if split == 'train' else slice(isplit_train, None)
            self.data = self.data[part]
            self.mask = self.mask[part] if has_mask else None

        self.transform = transforms.Compose([ transforms.ToTensor()]) if transform is None else transform

        self.sample_size = sample_size
        self.resample_info()

    def split_train_test(self, test_example,randstate):
        """Randomly keep 70% (``split='train'``) or the other 30% of the images.

        The permutation is seeded with ``randstate`` plus the position of the
        class in a fixed list, so the same arguments always select the same
        images. A private ``RandomState`` is used so the global NumPy random
        state is left untouched.

        This used to be wrapped in ``contextlib.contextmanager`` although it
        is not a generator; the decorator has been dropped so the method is a
        plain in-place mutator returning ``None``.
        """
        print(self.datatype, test_example,'split{}'.format(self.split))

        # The duplicated 'good' entry is kept on purpose: the list position
        # feeds the seed, so removing it would change the selected images.
        datatype = ['good', 'color', 'cut', 'good', 'hole', 'thread', 'metal_contamination']
        tmpdtype = 'good' if self.datatype=='train' else test_example
        n_images = self.data.shape[0]
        rng = np.random.RandomState(datatype.index(tmpdtype) + randstate)
        trainindexes = np.sort(rng.choice(np.arange(n_images), int(n_images*0.7), replace=False))
        # Sorted complement of the train indexes, computed without the
        # quadratic per-element membership test.
        testindexes = np.setdiff1d(np.arange(n_images), trainindexes)
        chosen = trainindexes if self.split=='train' else testindexes
        self.data = self.data[:][chosen]
        self.mask = self.mask[:][chosen] if test_example !='good'  else None

    def resample_info(self):
        """Rebuild the patch index and randomly keep at most ``sample_size`` rows."""
        self.dfinfo = self.create_sample_list()
        total = self.dfinfo.shape[0]
        n_keep = total if self.sample_size is None else min(self.sample_size, total)
        self.dfinfo = self.dfinfo.sample(n_keep, replace=False)

    def process_data(self, datatype, img_size, patch_shape, problem, processed_root_folder, root_folder, stride,
                     test_example):
        """Populate ``self.data`` / ``self.mask`` from the HDF5 cache, building it if needed.

        The cache lives at
        ``<processed_root_folder>/<H>_<W>/<datatype>/<class>/dataset.h5``.
        Arrays are read fully into memory so the file can be closed at once.

        NOTE: caches written before image and mask paths were sorted may hold
        mismatched image/mask pairs; delete them to force a rebuild.
        """
        if processed_root_folder is None:
            processed_root_folder = os.path.join(root_folder, 'processed', problem)
        processed_folder_path = os.path.join(processed_root_folder, '{}_{}'.format(*self.img_size), datatype,
                                             'good' if datatype == 'train' else test_example)
        os.makedirs(processed_folder_path, exist_ok=True)
        h5_path = os.path.join(processed_folder_path, 'dataset.h5')

        if os.path.exists(h5_path):
            # Context manager: the read handle used to be left open forever.
            with h5py.File(h5_path, mode='r') as hdata:
                self.data = self._read_h5_array(hdata['data'])
                self.mask = self._read_h5_array(hdata['mask'])
            return

        t = datetime.now()
        assert datatype in ['train', 'test']
        if datatype == 'train':
            path_img = os.path.join(root_folder, problem, datatype, 'good')
            path_mask = None
        else:
            assert test_example in ['color', 'cut', 'good', 'hole', 'thread', 'metal_contamination']
            path_img = os.path.join(root_folder, problem, datatype, test_example)
            path_mask = os.path.join(root_folder, problem, 'ground_truth', test_example)
            assert os.path.exists(path_mask)

        # glob gives no ordering guarantee; sort both listings so the i-th
        # mask is paired with the i-th image.
        img_files = sorted(glob.glob(os.path.join(path_img, '*.png')))
        mask_files = sorted(glob.glob(os.path.join(path_mask, '*.png'))) if path_mask is not None else None
        self.data, self.mask = self.read_images(img_files, mask_files)
        with h5py.File(h5_path, mode='w') as hdata:
            hdata['data'] = self.data
            hdata['mask'] = self.mask
        print('Data Loading Time {}'.format(datetime.now() - t))

    @staticmethod
    def _read_h5_array(dataset):
        """Return the content of an HDF5 dataset as an in-memory array."""
        if dataset.size:
            return dataset[:]
        # Zero-sized datasets (e.g. the mask of a train cache) are rebuilt
        # rather than read.
        return np.empty(dataset.shape, dtype=dataset.dtype)

    def download(self,root, problem):
        """Placeholder; does nothing."""
        pass

    def create_sample_list(self):
        """Return one row per (image, patch centre): columns ``row``, ``col``, ``img_id``."""
        centers = self.pm.center_list
        base = pd.DataFrame(dict(row=centers[:, 0], col=centers[:, 1]))
        # DataFrame.append was removed in pandas 2.0; build all frames and
        # concatenate once instead of growing a frame inside the loop.
        frames = [base.assign(img_id=i) for i in range(self.data.shape[0])]
        if not frames:
            return pd.DataFrame(columns=['row', 'col', 'img_id'])
        return pd.concat(frames, ignore_index=True)

    def _load_gray(self, path):
        """Read an image, convert colour input to grayscale, resize to ``img_size``."""
        img = io.imread(path)
        # Only colour images need conversion; 2-D input is passed through,
        # which presumably matches older scikit-image rgb2gray -- TODO confirm.
        if img.ndim == 3:
            img = rgb2gray(img)
        return resize(img, self.img_size)

    def read_images(self, img_paths, mask_paths=None):
        """Load images and, when available, their masks.

        Masks are loaded only if ``mask_paths`` is given and has as many
        entries as ``img_paths``; otherwise the returned mask array is empty.
        Paths are read in the order given.

        Returns
        -------
        (numpy.ndarray, numpy.ndarray)
            Stacked images and stacked masks.
        """
        has_masks = (mask_paths is not None) and (len(mask_paths) == len(img_paths))
        data = [self._load_gray(fimg) for fimg in img_paths]
        mask = [self._load_gray(fmask) for fmask in mask_paths] if has_masks else []
        return np.asarray(data), np.asarray(mask)

    def _extract_patch(self, image, row, col):
        """Cut the float32 patch centred on ``(row, col)``, rotated by ``self.angle``."""
        patch = self.pm.get_patch(img=image, center_point=(row, col), angle=self.angle)
        return patch.reshape(*self.patch_shape).astype(np.float32)

    def __getitem__(self, index):
        """Return the transformed patch, plus its transformed mask for ``datatype='test'``.

        When ``self.transform`` is ``None`` the raw float32 arrays are
        returned instead of failing on an unbound variable.
        """
        img_id, row, col = self.read_image(index)
        img = self._extract_patch(self.data[img_id], row, col)
        timg = self.transform(img) if self.transform is not None else img
        if self.datatype != 'test':
            return timg
        mask = self._extract_patch(self.mask[img_id], row, col)
        tmask = self.transform(mask) if self.transform is not None else mask
        return timg, tmask

    def read_image(self, index):
        """Return ``(img_id, row, col)`` for the ``index``-th sampled patch."""
        info = self.dfinfo.iloc[index]
        return int(info['img_id']), int(info['row']), int(info['col'])

    def __len__(self):
        """Number of sampled patches."""
        return self.dfinfo.shape[0]



In [19]:
def run(datatype,angle=0,vflip=False,hflip=False,split='train',
        root_folder='/content/SSL-research/www', min_defect_pixels=1):
    """Collect every 32x32 patch of one class, rotated by ``angle``, as one array.

    ``'good'`` patches come from the train folder (at most 1000, randomly
    sampled). Defect classes come from the test folder, and only patches
    whose mask has more than ``min_defect_pixels`` non-zero pixels are kept.

    Parameters
    ----------
    datatype : str
        One of 'good', 'color', 'cut', 'hole', 'thread', 'metal_contamination'.
    angle : float
        Patch rotation in degrees.
    vflip, hflip : bool
        Forwarded to ``ImageDataset``, which currently ignores them.
    split : str
        Forwarded to ``ImageDataset`` ('train' or 'test').
    root_folder : str
        Folder holding the dataset; defaults to the previously hard-coded path.
    min_defect_pixels : int
        A defect patch is kept when its mask has strictly more non-zero
        pixels than this; the default keeps the original ``> 1`` rule.

    Returns
    -------
    numpy.ndarray
        Stacked patches; an empty ``(0, 1, 32, 32)`` float32 array when the
        loader yields no batch at all.
    """
    assert datatype in ['good', 'color', 'cut', 'hole', 'thread', 'metal_contamination']
    patch_shape = (32, 32)
    is_good = datatype == 'good'
    dataset = ImageDataset(root_folder=root_folder,
                           datatype='train' if is_good else 'test', img_size=(512,512),
                           patch_shape=patch_shape, stride=(16,16),
                           test_example=datatype, sample_size=1000 if is_good else None,
                           vflip=vflip, hflip=hflip, angle=angle, split=split)

    loader = DataLoader(dataset, shuffle=False, batch_size=128, num_workers=4, pin_memory=False)
    X = []
    for batch in loader:
        if is_good:
            X.append(batch.numpy())
            continue
        img, mask = batch
        # Count the non-zero mask pixels of each patch in the batch.
        defect_pixels = (mask.reshape(mask.shape[0], -1).numpy() > 0).sum(axis=1)
        keep = np.where(defect_pixels > min_defect_pixels)[0]
        X.append(img.numpy()[keep])

    if not X:
        # np.concatenate refuses an empty list; return an empty stack instead.
        return np.empty((0, 1) + patch_shape, dtype=np.float32)
    return np.concatenate(X)

def loop_run(split='test'):
    """Write every class/angle combination of one split to ``DATASET_<split>.h5``.

    The output file has one group per class and, inside it, one dataset
    ``angle<deg>`` per rotation angle holding the array returned by ``run``.
    The class name and each angle are printed as progress markers.
    """
    # NOTE(review): this removes a file called 'dataset.h5' from the working
    # directory, which is not the file written below; opening the output with
    # mode='w' already truncates it.
    if os.path.exists('dataset.h5'):
        !rm -rf dataset.h5
    with h5py.File(f'DATASET_{split}.h5',mode='w') as f:
        for datatype in [ 'color', 'cut', 'good', 'hole', 'thread', 'metal_contamination']:
            print(datatype)
            group = f.create_group(datatype)
            # NOTE(review): the angles go up in steps of 15 but skip 135 --
            # confirm whether that is intended.
            for angle in [0,15,30,45,60,75, 90, 105, 120, 150]:
                print(angle)
                x = run( datatype=datatype, angle=angle, split=split)
                # return
                group['angle{}'.format(angle)] = x  
                # return
# Build the training split file (DATASET_train.h5).
loop_run(split='train')

color
0
15
30
45
60
75
90
105
120
150
cut
0
15
30
45
60
75
90
105
120
150
good
0
15
30
45
60
75
90
105
120
150
hole
0
15
30
45
60
75
90
105
120
150
thread
0
15
30
45
60
75
90
105
120
150
metal_contamination
0
15
30
45
60
75
90
105
120
150


In [20]:
# Build the held-out split file (DATASET_test.h5).
loop_run(split='test')

color
0
15
30
45
60
75
90
105
120
150
cut
0
15
30
45
60
75
90
105
120
150
good
0
15
30
45
60
75
90
105
120
150
hole
0
15
30
45
60
75
90
105
120
150
thread
0
15
30
45
60
75
90
105
120
150
metal_contamination
0
15
30
45
60
75
90
105
120
150
