Showing 12 changed files with 545 additions and 0 deletions.
@@ -99,3 +99,11 @@ ENV/

# mypy
.mypy_cache/

# project directories
datasets/
output/A/
output/B/

# model checkpoints
*.pth
@@ -0,0 +1,28 @@
import glob
import random
import os

from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms

class ImageDataset(Dataset):
    def __init__(self, root, transforms_=None, unaligned=False, mode='train'):
        self.transform = transforms.Compose(transforms_)
        self.unaligned = unaligned

        self.files_A = sorted(glob.glob(os.path.join(root, '%s/A' % mode) + '/*.jpg'))
        self.files_B = sorted(glob.glob(os.path.join(root, '%s/B' % mode) + '/*.jpg'))

    def __getitem__(self, index):
        item_A = self.transform(Image.open(self.files_A[index % len(self.files_A)]))

        if self.unaligned:
            item_B = self.transform(Image.open(self.files_B[random.randint(0, len(self.files_B) - 1)]))
        else:
            item_B = self.transform(Image.open(self.files_B[index % len(self.files_B)]))

        return {'A': item_A, 'B': item_B}

    def __len__(self):
        return max(len(self.files_A), len(self.files_B))
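For orientation, here is a minimal usage sketch of the dataset class above. It is not part of the commit; the dataroot path, transform list, and printed shapes are assumptions based on the defaults used elsewhere in this commit.

# Hypothetical usage sketch -- dataroot, transforms and shapes are assumptions.
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

transforms_ = [transforms.ToTensor(),
               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
dataset = ImageDataset('datasets/horse2zebra', transforms_=transforms_,
                       unaligned=True, mode='train')
loader = DataLoader(dataset, batch_size=1, shuffle=True)

batch = next(iter(loader))
print(batch['A'].shape, batch['B'].shape)  # e.g. torch.Size([1, 3, 256, 256]) each

With unaligned=True the B image is drawn at random, so the A/B pairing changes between iterations; with the default unaligned=False the pairing is simply by sorted index.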
@@ -0,0 +1,23 @@
#!/bin/bash

FILE=$1

if [[ $FILE != "ae_photos" && $FILE != "apple2orange" && $FILE != "summer2winter_yosemite" && $FILE != "horse2zebra" && $FILE != "monet2photo" && $FILE != "cezanne2photo" && $FILE != "ukiyoe2photo" && $FILE != "vangogh2photo" && $FILE != "maps" && $FILE != "cityscapes" && $FILE != "facades" && $FILE != "iphone2dslr_flower" ]]; then
    echo "Available datasets are: apple2orange, summer2winter_yosemite, horse2zebra, monet2photo, cezanne2photo, ukiyoe2photo, vangogh2photo, maps, cityscapes, facades, iphone2dslr_flower, ae_photos"
    exit 1
fi

URL=https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/$FILE.zip
ZIP_FILE=./datasets/$FILE.zip
TARGET_DIR=./datasets/$FILE
mkdir -p datasets
wget -N "$URL" -O "$ZIP_FILE"
unzip "$ZIP_FILE" -d ./datasets/
rm "$ZIP_FILE"

# Adapt to the directory hierarchy expected by the project
mkdir -p "$TARGET_DIR/train" "$TARGET_DIR/test"
mv "$TARGET_DIR/trainA" "$TARGET_DIR/train/A"
mv "$TARGET_DIR/trainB" "$TARGET_DIR/train/B"
mv "$TARGET_DIR/testA" "$TARGET_DIR/test/A"
mv "$TARGET_DIR/testB" "$TARGET_DIR/test/B"
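For example, running ./download_dataset.sh horse2zebra should (assuming the download succeeds) leave the images under datasets/horse2zebra/train/A, train/B, test/A and test/B, which is exactly the layout the ImageDataset class above globs for.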
@@ -0,0 +1,92 @@
import torch.nn as nn
import torch.nn.functional as F

class ResidualBlock(nn.Module):
    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()

        conv_block = [  nn.ReflectionPad2d(1),
                        nn.Conv2d(in_features, in_features, 3),
                        nn.InstanceNorm2d(in_features),
                        nn.ReLU(inplace=True),
                        nn.ReflectionPad2d(1),
                        nn.Conv2d(in_features, in_features, 3),
                        nn.InstanceNorm2d(in_features)  ]

        self.conv_block = nn.Sequential(*conv_block)

    def forward(self, x):
        return x + self.conv_block(x)

class Generator(nn.Module):
    def __init__(self, input_nc, output_nc, n_residual_blocks=9):
        super(Generator, self).__init__()

        # Initial convolution block
        model = [   nn.ReflectionPad2d(3),
                    nn.Conv2d(input_nc, 64, 7),
                    nn.InstanceNorm2d(64),
                    nn.ReLU(inplace=True) ]

        # Downsampling
        in_features = 64
        out_features = in_features*2
        for _ in range(2):
            model += [  nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                        nn.InstanceNorm2d(out_features),
                        nn.ReLU(inplace=True) ]
            in_features = out_features
            out_features = in_features*2

        # Residual blocks
        for _ in range(n_residual_blocks):
            model += [ResidualBlock(in_features)]

        # Upsampling
        out_features = in_features//2
        for _ in range(2):
            model += [  nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
                        nn.InstanceNorm2d(out_features),
                        nn.ReLU(inplace=True) ]
            in_features = out_features
            out_features = in_features//2

        # Output layer
        model += [  nn.ReflectionPad2d(3),
                    nn.Conv2d(64, output_nc, 7),
                    nn.Tanh() ]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)

class Discriminator(nn.Module):
    def __init__(self, input_nc):
        super(Discriminator, self).__init__()

        # A bunch of convolutions one after another
        model = [   nn.Conv2d(input_nc, 64, 4, stride=2, padding=1),
                    nn.LeakyReLU(0.2, inplace=True) ]

        model += [  nn.Conv2d(64, 128, 4, stride=2, padding=1),
                    nn.InstanceNorm2d(128),
                    nn.LeakyReLU(0.2, inplace=True) ]

        model += [  nn.Conv2d(128, 256, 4, stride=2, padding=1),
                    nn.InstanceNorm2d(256),
                    nn.LeakyReLU(0.2, inplace=True) ]

        model += [  nn.Conv2d(256, 512, 4, padding=1),
                    nn.InstanceNorm2d(512),
                    nn.LeakyReLU(0.2, inplace=True) ]

        # FCN classification layer
        model += [nn.Conv2d(512, 1, 4, padding=1)]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        x = self.model(x)
        # Average pooling and flatten
        return F.avg_pool2d(x, x.size()[2:]).view(x.size()[0], -1)
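As a quick sanity check of the architectures above (not part of the commit; the 256x256 input size and batch size of 2 are assumptions), the generator preserves the input resolution while the discriminator averages its PatchGAN-style output map down to a single score per image:

# Hypothetical shape check -- input size and batch size are assumptions.
import torch

netG = Generator(input_nc=3, output_nc=3)
netD = Discriminator(input_nc=3)

x = torch.randn(2, 3, 256, 256)
print(netG(x).shape)  # torch.Size([2, 3, 256, 256]) -- same resolution as the input
print(netD(x).shape)  # torch.Size([2, 1]) -- one averaged score per image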
(5 additional files in this commit could not be displayed.)
@@ -0,0 +1,88 @@
#!/usr/bin/python3

import argparse
import sys
import os

import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torch.autograd import Variable
import torch

from models import Generator
from datasets import ImageDataset

parser = argparse.ArgumentParser()
parser.add_argument('--batchSize', type=int, default=1, help='size of the batches')
parser.add_argument('--dataroot', type=str, default='datasets/horse2zebra/', help='root directory of the dataset')
parser.add_argument('--input_nc', type=int, default=3, help='number of channels of input data')
parser.add_argument('--output_nc', type=int, default=3, help='number of channels of output data')
parser.add_argument('--size', type=int, default=256, help='size of the data (squared assumed)')
parser.add_argument('--cuda', action='store_true', help='use GPU computation')
parser.add_argument('--n_cpu', type=int, default=8, help='number of cpu threads to use during batch generation')
parser.add_argument('--generator_A2B', type=str, default='output/netG_A2B.pth', help='A2B generator checkpoint file')
parser.add_argument('--generator_B2A', type=str, default='output/netG_B2A.pth', help='B2A generator checkpoint file')
opt = parser.parse_args()
print(opt)

if torch.cuda.is_available() and not opt.cuda:
    print("WARNING: You have a CUDA device, so you should probably run with --cuda")

###### Definition of variables ######
# Networks
netG_A2B = Generator(opt.input_nc, opt.output_nc)
netG_B2A = Generator(opt.output_nc, opt.input_nc)

if opt.cuda:
    netG_A2B.cuda()
    netG_B2A.cuda()

# Load state dicts
netG_A2B.load_state_dict(torch.load(opt.generator_A2B))
netG_B2A.load_state_dict(torch.load(opt.generator_B2A))

# Set model's test mode
netG_A2B.eval()
netG_B2A.eval()

# Inputs & targets memory allocation
Tensor = torch.cuda.FloatTensor if opt.cuda else torch.Tensor
input_A = Tensor(opt.batchSize, opt.input_nc, opt.size, opt.size)
input_B = Tensor(opt.batchSize, opt.output_nc, opt.size, opt.size)

# Dataset loader
transforms_ = [ # transforms.Resize(int(opt.size*1.12), Image.BICUBIC),
                # transforms.RandomCrop(opt.size),
                # transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5)) ]
dataloader = DataLoader(ImageDataset(opt.dataroot, transforms_=transforms_, mode='test'),
                        batch_size=opt.batchSize, shuffle=False, num_workers=opt.n_cpu)
###################################

###### Testing ######

# Create output dirs if they don't exist
if not os.path.exists('output/A'):
    os.makedirs('output/A')
if not os.path.exists('output/B'):
    os.makedirs('output/B')

for i, batch in enumerate(dataloader):
    # Set model input
    real_A = Variable(input_A.copy_(batch['A']))
    real_B = Variable(input_B.copy_(batch['B']))

    # Generate output
    fake_B = 0.5*(netG_A2B(real_A).data + 1.0)
    fake_A = 0.5*(netG_B2A(real_B).data + 1.0)

    # Save image files
    save_image(fake_A, 'output/A/%04d.png' % (i+1))
    save_image(fake_B, 'output/B/%04d.png' % (i+1))

    sys.stdout.write('\rGenerated images %04d of %04d' % (i+1, len(dataloader)))

sys.stdout.write('\n')
###################################
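Assuming generators have already been trained and saved to the default checkpoint paths output/netG_A2B.pth and output/netG_B2A.pth, an invocation along the lines of python test.py --dataroot datasets/horse2zebra/ --cuda writes the translated images to output/A/0001.png, output/B/0001.png and so on. The 0.5*(x + 1.0) step simply rescales the generators' Tanh output from [-1, 1] back to [0, 1] before saving with save_image.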