-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
440 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
.DS_Store | ||
__pycache__ | ||
*.pyc | ||
*.py~ | ||
dataset |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,169 @@ | ||
import h5py | ||
import torch | ||
import torch.utils.data as data | ||
from torchvision import datasets, transforms | ||
import os | ||
import numpy as np | ||
from PIL import Image | ||
import urllib.request | ||
import scipy.io | ||
|
||
class fixedMNIST(data.Dataset):
    """Binarized MNIST dataset with a fixed binarization, proposed in
    http://proceedings.mlr.press/v15/larochelle11a/larochelle11a.pdf

    The three official .amat splits are downloaded, train+valid are merged
    into the training set, and everything is cached in ``<root>/data.h5``.
    """
    train_file = 'binarized_mnist_train.amat'
    val_file = 'binarized_mnist_valid.amat'
    test_file = 'binarized_mnist_test.amat'

    def __init__(self, root, train=True, transform=None, download=False):
        """Load (optionally downloading) the cached binarized MNIST data.

        `transform` is accepted for interface compatibility with
        torchvision-style datasets but ignored: __getitem__ always applies
        ToTensor.
        """
        self.root = os.path.expanduser(root)
        self.train = train  # training set or test set

        if download:
            self.download()
        if not self._check_exists():
            raise RuntimeError('Dataset not found.' + ' You can use download=True to download it')

        self.data = self._get_data(train=train)

    def __getitem__(self, index):
        img = self.data[index]
        img = Image.fromarray(img)
        img = transforms.ToTensor()(img).type(torch.FloatTensor)
        return img, torch.tensor(-1)  # Meaningless tensor instead of target

    def __len__(self):
        return len(self.data)

    def _get_data(self, train=True):
        # Read the requested split from the HDF5 cache built by download().
        with h5py.File(os.path.join(self.root, 'data.h5'), 'r') as hf:
            data = np.array(hf.get('train' if train else 'test'))
        return data

    def get_mean_img(self):
        """Per-pixel mean over the loaded split, flattened to 784 values."""
        return self.data.mean(0).flatten()

    def download(self):
        """Fetch the .amat splits and cache them as ``<root>/data.h5``."""
        if self._check_exists():
            return
        if not os.path.exists(self.root):
            os.makedirs(self.root)

        print('Downloading MNIST with fixed binarization...')
        for dataset in ['train', 'valid', 'test']:
            filename = 'binarized_mnist_{}.amat'.format(dataset)
            url = 'http://www.cs.toronto.edu/~larocheh/public/datasets/binarized_mnist/binarized_mnist_{}.amat'.format(dataset)
            print('Downloading from {}...'.format(url))
            local_filename = os.path.join(self.root, filename)
            urllib.request.urlretrieve(url, local_filename)
            print('Saved to {}'.format(local_filename))

        def filename_to_np(filename):
            # Each .amat line is 784 space-separated 0/1 pixels.
            with open(filename) as f:
                lines = f.readlines()
            return np.array([[int(i) for i in line.split()] for line in lines]).astype('int8')

        train_data = np.concatenate([filename_to_np(os.path.join(self.root, self.train_file)),
                                     filename_to_np(os.path.join(self.root, self.val_file))])
        # BUG FIX: the test split previously reloaded the *validation* file
        # (self.val_file), so reported test NLLs were validation NLLs.
        test_data = filename_to_np(os.path.join(self.root, self.test_file))
        with h5py.File(os.path.join(self.root, 'data.h5'), 'w') as hf:
            hf.create_dataset('train', data=train_data.reshape(-1, 28, 28))
            hf.create_dataset('test', data=test_data.reshape(-1, 28, 28))
        print('Done!')

    def _check_exists(self):
        # The HDF5 cache is the single artifact required to use the dataset.
        return os.path.exists(os.path.join(self.root, 'data.h5'))
|
||
class stochMNIST(datasets.MNIST):
    """MNIST with dynamic binarization: a fresh Bernoulli draw at every access."""

    def __getitem__(self, index):
        """Return ``(binarized image tensor, label)`` for the given index."""
        if self.train:
            raw, target = self.train_data[index], self.train_labels[index]
        else:
            raw, target = self.test_data[index], self.test_labels[index]

        pil_img = Image.fromarray(raw.numpy(), mode='L')
        tensor_img = transforms.ToTensor()(pil_img)
        binarized = torch.bernoulli(tensor_img)  # stochastically binarize
        return binarized, target

    def get_mean_img(self):
        """Mean training image (grayscale in [0, 1]) flattened to 784 values."""
        scaled = self.train_data.type(torch.float) / 255
        return scaled.mean(0).reshape(-1).numpy()
||
|
||
class omniglot(data.Dataset):
    """Omniglot dataset (IWAE preprocessing), stochastically binarized."""
    url = 'https://github.com/yburda/iwae/raw/master/datasets/OMNIGLOT/chardata.mat'

    def __init__(self, root, train=True, transform=None, download=False):
        """Load (optionally downloading) ``chardata.mat`` from `root`.

        `transform` is accepted for interface compatibility with
        torchvision-style datasets but ignored: __getitem__ always applies
        ToTensor.
        """
        self.root = os.path.expanduser(root)
        self.train = train  # training set or test set

        if download:
            self.download()
        if not self._check_exists():
            raise RuntimeError('Dataset not found. You can use download=True to download it')

        self.data = self._get_data(train=train)

    def __getitem__(self, index):
        img = self.data[index].reshape(28, 28)
        img = Image.fromarray(img)
        img = transforms.ToTensor()(img).type(torch.FloatTensor)
        img = torch.bernoulli(img)  # stochastically binarize
        return img, torch.tensor(-1)  # Meaningless tensor instead of target

    def __len__(self):
        return len(self.data)

    def _get_data(self, train=True):
        def reshape_data(data):
            # FIX: 'F' replaces the undocumented spelling 'fortran', which
            # NumPy only accepted by accident of first-letter matching.
            # The .mat file stores images column-major, so the flatten back
            # to 784 must use Fortran order to keep pixel layout correct.
            return data.reshape((-1, 28, 28)).reshape((-1, 28 * 28), order='F')

        omni_raw = scipy.io.loadmat(os.path.join(self.root, 'chardata.mat'))
        data_str = 'data' if train else 'testdata'
        return reshape_data(omni_raw[data_str].T.astype('float32'))

    def get_mean_img(self):
        """Per-pixel mean over the loaded split (already flat, 784 values)."""
        return self.data.mean(0)

    def download(self):
        """Fetch ``chardata.mat`` from the IWAE repository into `root`."""
        if self._check_exists():
            return
        if not os.path.exists(self.root):
            os.makedirs(self.root)

        print('Downloading from {}...'.format(self.url))
        local_filename = os.path.join(self.root, 'chardata.mat')
        urllib.request.urlretrieve(self.url, local_filename)
        print('Saved to {}'.format(local_filename))

    def _check_exists(self):
        return os.path.exists(os.path.join(self.root, 'chardata.mat'))
|
||
def data_loaders(dataset, dataset_dir, batch_size, eval_batch_size):
    """Build train/test DataLoaders for the requested dataset.

    Args:
        dataset: one of 'omniglot', 'fixed_mnist', 'sto_mnist'.
        dataset_dir: optional base directory for the data; when falsy
            ('' or None) the default './dataset/<name>' location is used.
        batch_size: training batch size.
        eval_batch_size: evaluation batch size.

    Returns:
        (train_loader, test_loader) tuple.

    Raises:
        NotImplementedError: for an unknown `dataset` name.
    """
    if dataset == 'omniglot':
        loader_fn, root = omniglot, './dataset/omniglot'
    elif dataset == 'fixed_mnist':
        loader_fn, root = fixedMNIST, './dataset/fixedmnist'
    elif dataset == 'sto_mnist':
        loader_fn, root = stochMNIST, './dataset/stochmnist'
    else:
        raise NotImplementedError('Dataset not supported yet!')

    # FIX: `dataset_dir` was accepted but silently ignored; honor it now.
    # The default '' keeps the historical './dataset/...' locations.
    if dataset_dir:
        root = os.path.join(dataset_dir, os.path.basename(root))

    kwargs = {'num_workers': 4, 'pin_memory': True} if torch.cuda.is_available() else {}

    train_loader = torch.utils.data.DataLoader(
        loader_fn(root, train=True, download=True, transform=transforms.ToTensor()),
        batch_size=batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(  # need test bs <=64 to make L_5000 tractable in one pass
        loader_fn(root, train=False, download=True, transform=transforms.ToTensor()),
        batch_size=eval_batch_size, shuffle=False, **kwargs)
    return train_loader, test_loader
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
from __future__ import print_function | ||
|
||
import argparse | ||
import torch | ||
import torch.utils.data | ||
import numpy as np | ||
import os | ||
import math | ||
|
||
from torch import nn, optim | ||
from torch.nn import functional as F | ||
from torchvision import datasets, transforms | ||
from torchvision.utils import save_image | ||
|
||
import data | ||
import models | ||
|
||
from torch.distributions.normal import Normal | ||
from torch.distributions.multivariate_normal import MultivariateNormal | ||
|
||
## Command-line configuration for REM density-estimation training.
parser = argparse.ArgumentParser(description='Density estimation with REM')
parser.add_argument('--version', type=str, default='v1', help='choice: v1, v2')
parser.add_argument('--dataset', type=str, default='fixed_mnist', help='choice: fixed_mnist, sto_mnist, omniglot')
parser.add_argument('--dataset_dir', type=str, default='')
parser.add_argument('--batch_size', type=int, default=20)
parser.add_argument('--eval_batch_size', type=int, default=64)
parser.add_argument('--epochs', type=int, default=200)
parser.add_argument('--seed', type=int, default=2019)
parser.add_argument('--log', type=int, default=300, help='when to log training progress')
parser.add_argument('--x_dim', type=int, default=784)  # flattened 28x28 input
parser.add_argument('--z_dim', type=int, default=20)   # latent dimensionality
parser.add_argument('--h_dim', type=int, default=200)  # hidden-layer width
parser.add_argument('--save_path', type=str, default='results')
parser.add_argument('--n_samples_train', type=int, default=1000, help='number of importance samples at training')
parser.add_argument('--n_samples_test', type=int, default=1000, help='number of importance samples at testing')
parser.add_argument('--learning_rate', type=float, default=1e-3)
parser.add_argument('--save_every', type=int, default=5)


args = parser.parse_args()

## create folder where to dump results if it doesn't exist
if not os.path.exists(args.save_path):
    os.makedirs(args.save_path)

## set seed for reproducibility (numpy + torch; force deterministic cudnn)
np.random.seed(args.seed)
torch.backends.cudnn.deterministic = True
torch.manual_seed(args.seed)

## get the right device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## set checkpoint file, e.g. results/REM_v1_fixed_mnist_1000.pt
ckpt = os.path.join(args.save_path, 'REM_{}_{}_{}.pt'.format(args.version, args.dataset, args.n_samples_train))

## get data loaders
train_loader, test_loader = data.data_loaders(args.dataset, args.dataset_dir, args.batch_size, args.eval_batch_size)

## get model (REM signature presumably (x_dim, z_dim, h_dim, version) — defined in models.py, not visible here)
model = models.REM(args.x_dim, args.z_dim, args.h_dim, args.version)
model = model.to(device)

## Adam with staged LR decay: milestones at cumulative powers of 3
## (1, 4, 13, 40, ...), each step scaling the LR by 10**(-1/7), i.e. a
## full decade of decay spread over 7 milestones.
optimizer = optim.Adam(model.parameters(), lr=args.learning_rate, eps=1e-4)
milestones = np.cumsum([3**i for i in range(8)])
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=10**(-1/7))
||
def train(epoch):
    """Run one training epoch, printing progress every `args.log` batches.

    Uses the module-level `model`, `optimizer`, `train_loader`, `device`,
    and `args`.
    """
    model.train()
    train_loss = 0
    # FIX: loop variable renamed from `data` to `batch` so it no longer
    # shadows the module-level `import data`.
    for batch_idx, (batch, _) in enumerate(train_loader):
        batch = batch.to(device)
        optimizer.zero_grad()
        loss = model(batch, args.n_samples_train)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        if batch_idx % args.log == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(batch), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item() / len(batch)))
    print('\n')
    print('*'*100)
    print('====> Epoch: {}/{} Average loss: {:.4f}'.format(
        epoch, args.epochs, train_loss / len(train_loader.dataset)))
    print('*'*100)
    print('\n')
||
def test(name):
    """Evaluate the importance-sampled NLL on the 'train' or 'test' split."""
    model.eval()
    eval_loader = train_loader if name == 'train' else test_loader
    with torch.no_grad():
        nll = model.log_lik(eval_loader, args.n_samples_test)
    print('REM_{} {} {} NLL: {}'.format(args.version, args.dataset.upper(), name.upper(), nll))
    return nll
||
def _append_nll(path, value):
    """Append one NLL value on its own line to the log file at `path`."""
    with open(path, 'a') as f:
        f.write(str(value))
        f.write('\n')


## Per-run NLL log files, keyed by version/dataset/sample count.
train_nll_file = os.path.join(args.save_path, 'REM_{}_{}_{}_train_nll.txt'.format(
    args.version, args.dataset, args.n_samples_train))
test_nll_file = os.path.join(args.save_path, 'REM_{}_{}_{}_test_nll.txt'.format(
    args.version, args.dataset, args.n_samples_train))

## Main loop: train an epoch, log train/test NLL, periodically checkpoint.
for epoch in range(1, args.epochs+1):
    print('\n')
    train(epoch)

    _append_nll(train_nll_file, test(name='train'))
    _append_nll(test_nll_file, test(name='test'))

    if epoch % args.save_every == 0:
        torch.save(model.state_dict(), ckpt)
Oops, something went wrong.