Skip to content

Commit

Permalink
first commit
Browse files Browse the repository at this point in the history
  • Loading branch information
D-X-Y committed Mar 11, 2018
0 parents commit e7245dd
Show file tree
Hide file tree
Showing 42 changed files with 2,207 additions and 0 deletions.
16 changes: 16 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
data
*.pyc
*/*.pyc
*/*/*.pyc
.*.swp
*.pth.tar
*.png
test_utils/Generate_*.sh
./cache_data/*.lst
Temp*.sh
cache_data/lists
snapshots
cache_data/challenging.pdf
cache_data/common.pdf
cache_data/full.pdf
cache_data/cache/
56 changes: 56 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Style Aggregated Network for Facial Landmark Detection

We provide the training and testing codes for [SAN](https://d-x-y.github.io/publication/style-aggregation-network), implemented in [PyTorch](pytorch.org).

## Preparation

### Dependencies
- [Python3.6](https://www.anaconda.com/download/#linux)
- [PyTorch](http://pytorch.org/)
- [torchvision](http://pytorch.org/docs/master/torchvision)

### Datasets download
- Download 300W-Style and AFLW-Style from [Google Drive](https://drive.google.com/open?id=14f2lcJVF6E4kIICd8icUs8UuF3J0Mutd), and extract the downloaded files into `~/datasets/`.
- In 300W-Style and AFLW-Style directories, the `Original` sub-directory contains the original images from [300-W](https://ibug.doc.ic.ac.uk/resources/300-W/) and [AFLW](https://www.tugraz.at/institute/icg/research/team-bischof/lrs/downloads/aflw/)
<img src="cache_data/cache/dataset.jpg" width="480">
Figure 1. Our 300W-Style and AFLW-Style datasets. There are four styles, original, sketch, light, and gray.

# The Core Codes will come soon before 1st June.

### Generate lists for training and evaluation
```
cd cache_data
python aflw_from_mat.py
python generate_300W.py
```
The generated list file will be saved into `./cache_data/lists/300W` and `./cache_data/lists/AFLW`.

### Prepare images for the style-aggregated face generation module
```
python crop_pic.py
```
The above commands are used to pre-crop the face images.

## Training and Evaluation

### 300-W
- Step-1 : cluster images into different groups, for example `sh scripts/300W/300W_Cluster.sh 0,1 GTB 3`.
- Step-2 : use `sh scripts/300W/300W_CYCLE_128.sh 0,1 GTB` or `sh scripts/300W/300W_CYCLE_128.sh 0,1 DET` to train SAN on 300-W.

### AFLW
- Step-1 : cluster images into different groups, for example `sh scripts/AFLW/AFLW_Cluster.sh 0,1 GTB 3`.
- Step-2 : use `sh scripts/AFLW/AFLW_CYCLE_128.FULL.sh` or `sh scripts/AFLW/AFLW_CYCLE_128.FRONT.sh` to train SAN on AFLW.

## Citation
Please cite the following paper in your publications if it helps your research:
```
@inproceedings{dong2018san,
title={Style Aggregated Network for Facial Landmark Detection},
author={Dong, Xuanyi and Yan, Yan and Ouyang, Wanli and Yi, Yang},
booktitle={Computer Vision and Pattern Recognition},
year={2018},
}
```

## Contact
To ask questions or report issues, please open an issue on the [issues tracker](https://github.com/D-X-Y/SAN/issues).
123 changes: 123 additions & 0 deletions base_cluster.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
##############################################################
### Copyright (c) 2018-present, Xuanyi Dong ###
### Style Aggregated Network for Facial Landmark Detection ###
### Computer Vision and Pattern Recognition, 2018 ###
##############################################################
from __future__ import division

import os, sys, time, random, argparse, PIL
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True # please use Pillow 4.0.0 or it may fail for some images
from os import path as osp
import numbers, numpy as np
import init_path
import torch
import datasets
from shutil import copyfile
from san_vision import transforms
from utils import AverageMeter, print_log
from utils import convert_size2str, convert_secs2time, time_string, time_for_file
from visualization import draw_image_by_points, save_error_image
import debug, models, options
from sklearn.cluster import KMeans
from cluster import filter_cluster

model_names = sorted(name for name in models.__dict__
if name.islower() and not name.startswith("__")
and callable(models.__dict__[name]))

opt = options.Options(model_names)
args = opt.opt
# Prepare options
if args.manualSeed is None: args.manualSeed = random.randint(1, 10000)
random.seed(args.manualSeed)
torch.manual_seed(args.manualSeed)
torch.cuda.manual_seed_all(args.manualSeed)
torch.backends.cudnn.enabled = True
#torch.backends.cudnn.benchmark = True

def main():
# Init logger
if not os.path.isdir(args.save_path): os.makedirs(args.save_path)
log = open(os.path.join(args.save_path, 'cluster_seed_{}_{}.txt'.format(args.manualSeed, time_for_file())), 'w')
print_log('save path : {}'.format(args.save_path), log)
print_log('------------ Options -------------', log)
for k, v in sorted(vars(args).items()):
print_log('Parameter : {:20} = {:}'.format(k, v), log)
print_log('-------------- End ----------------', log)
print_log("Random Seed: {}".format(args.manualSeed), log)
print_log("python version : {}".format(sys.version.replace('\n', ' ')), log)
print_log("Pillow version : {}".format(PIL.__version__), log)
print_log("torch version : {}".format(torch.__version__), log)
print_log("cudnn version : {}".format(torch.backends.cudnn.version()), log)

# General Data Argumentation
mean_fill = tuple( [int(x*255) for x in [0.485, 0.456, 0.406] ] )
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
transform = transforms.Compose([transforms.PreCrop(args.pre_crop_expand), transforms.TrainScale2WH((args.crop_width, args.crop_height)), transforms.ToTensor(), normalize])

args.downsample = 8 # By default
args.sigma = args.sigma * args.scale_eval

data = datasets.GeneralDataset(transform, args.sigma, args.downsample, args.heatmap_type, args.dataset_name)
data.load_list(args.train_list, args.num_pts, True)
loader = torch.utils.data.DataLoader(data, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True)

# Load all lists
all_lines = {}
for file_path in args.train_list:
listfile = open(file_path, 'r')
listdata = listfile.read().splitlines()
listfile.close()
for line in listdata:
temp = line.split(' ')
assert len(temp) == 6 or len(temp) == 7, 'This line has the wrong format : {}'.format(line)
image_path = temp[0]
all_lines[ image_path ] = line

assert args.n_clusters >= 2, 'The cluster number must be greater than 2'
resnet = models.resnet152(True).cuda()
all_features = []
for i, (inputs, target, mask, points, image_index, label_sign, ori_size) in enumerate(loader):
input_vars = torch.autograd.Variable(inputs.cuda(), volatile=True)
features, classifications = resnet(input_vars)
features = features.cpu().data.numpy()
all_features.append( features )
if i % args.print_freq == 0:
print_log('{} {}/{} extract features'.format(time_string(), i, len(loader)), log)
all_features = np.concatenate(all_features, axis=0)
kmeans_result = KMeans(n_clusters=args.n_clusters, n_jobs=args.workers).fit( all_features )
print_log('kmeans [{}] calculate done'.format(args.n_clusters), log)
labels = kmeans_result.labels_.copy()

cluster_idx = []
for iL in range(args.n_clusters):
indexes = np.where( labels == iL )[0]
cluster_idx.append( len(indexes) )
cluster_idx = np.argsort(cluster_idx)

for iL in range(args.n_clusters):
ilabel = cluster_idx[iL]
indexes = np.where( labels == ilabel )
if isinstance(indexes, tuple) or isinstance(indexes, list): indexes = indexes[0]
cluster_features = all_features[indexes,:].copy()
filtered_index = filter_cluster(indexes.copy(), cluster_features, 0.8)

print_log('{:} [{:2d} / {:2d}] has {:4d} / {:4d} -> {:4d} = {:.2f} images '.format(time_string(), iL, args.n_clusters, indexes.size, len(data), len(filtered_index), indexes.size*1./len(data)), log)
indexes = filtered_index.copy()
save_dir = osp.join(args.save_path, 'cluster-{:02d}-{:02d}'.format(iL, args.n_clusters))
save_path = save_dir + '.lst'
#if not osp.isdir(save_path): os.makedirs( save_path )
print_log('save into {}'.format(save_path), log)
txtfile = open( save_path , 'w')
for idx in indexes:
image_path = data.datas[idx]
assert image_path in all_lines, 'Not find {}'.format(image_path)
txtfile.write('{}\n'.format(all_lines[image_path]))
#basename = osp.basename( image_path )
#os.system( 'cp {} {}'.format(image_path, save_dir) )
txtfile.close()

if __name__ == '__main__':
main()
3 changes: 3 additions & 0 deletions cache_data/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
*.lst
temp
EX300.sh
Binary file added cache_data/AFLWinfo_release.mat
Binary file not shown.
Loading

0 comments on commit e7245dd

Please sign in to comment.