forked from oawiles/X2Face
-
Notifications
You must be signed in to change notification settings - Fork 2
/
reconstruction.py
59 lines (45 loc) · 2.11 KB
/
reconstruction.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import os
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
import numpy as np
import imageio
import argparse
from VoxCelebData_withmask import FramesDataset
from UnwrappedFace import UnwrappedFaceWeightedAverage
from torch.autograd import Variable
def reconstruction_loss(a, b):
    """Return the mean absolute (L1) difference between tensors ``a`` and ``b``.

    ``a`` and ``b`` must be broadcast-compatible; the result is a scalar tensor.
    """
    difference = a - b
    return difference.abs().mean()
def reconstruction(generator, checkpoint, log_dir, dataset, format='.gif'):
    """Reconstruct every video in ``dataset`` from its first frame and save the results.

    For each video, frame 0 is used as the appearance source and every frame
    (including frame 0) drives ``generator``. The reconstructed clip is written
    to ``<log_dir>/reconstruction/<name><format>`` and the L1 error against the
    input is accumulated.

    Args:
        generator: module called as ``generator(driving_frame, appearance_frame)``;
            its weights are replaced by the state dict stored at ``checkpoint``.
        checkpoint: path to a state-dict file readable by ``torch.load``.
        log_dir: parent directory; outputs go into its ``reconstruction`` subfolder.
        dataset: map-style dataset whose items are dicts with a ``'video_array'``
            tensor and a ``'name'`` string.
        format: file extension (e.g. ``'.gif'``) passed through to ``imageio.mimsave``.

    Returns:
        The mean reconstruction loss over all videos (``nan`` if the dataset is empty).
        The value is also printed.
    """
    log_dir = os.path.join(log_dir, 'reconstruction')
    os.makedirs(log_dir, exist_ok=True)

    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    # Loading onto the CPU works regardless of the device the checkpoint was saved
    # from; load_state_dict then copies the weights onto the generator's own device.
    state_dict = torch.load(checkpoint, map_location='cpu')
    generator.load_state_dict(state_dict)
    generator.eval()

    # Run inputs wherever the model lives instead of hard-coding CUDA.
    device = next(generator.parameters()).device
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)

    loss_list = []
    # no_grad replaces the removed Variable(volatile=True) flag; without it
    # autograd would record a graph across every generated frame.
    with torch.no_grad():
        for x in tqdm(dataloader):
            imgs = x['video_array'].to(device)
            results = _reconstruct_video(generator, imgs)
            loss_list.append(reconstruction_loss(imgs, results).item())
            frames = _to_uint8_frames(results[0])
            imageio.mimsave(os.path.join(log_dir, x['name'][0] + format), frames)

    # np.mean([]) would emit a RuntimeWarning; report nan explicitly instead.
    mean_loss = float(np.mean(loss_list)) if loss_list else float('nan')
    print("Reconstruction loss: %s" % mean_loss)
    return mean_loss


def _reconstruct_video(generator, imgs):
    """Regenerate every frame of ``imgs`` using frame 0 as the appearance source."""
    # imgs is indexed as (batch, time, ...) -- time is dim 1.
    appearance = imgs[:, 0]
    frames = [generator(imgs[:, t], appearance) for t in range(imgs.size(1))]
    # Stacking on dim 1 equals concatenating the frames after unsqueeze(dim=1).
    return torch.stack(frames, dim=1)


def _to_uint8_frames(video):
    """Convert a (T, C, H, W) float tensor to a (T, H, W, C) uint8 array for imageio."""
    # Values are assumed to lie in [0, 1]; clamp so out-of-range outputs
    # saturate instead of wrapping around in the uint8 cast.
    array = video.detach().cpu().clamp(0, 1).numpy()
    return (array * 255).astype(np.uint8).transpose((0, 2, 3, 1))
if __name__ == "__main__":
    # Command-line entry point: rebuild the generator, load <folder>/model.cpk
    # and write reconstructions of the evaluation split into <folder>.
    cli = argparse.ArgumentParser(description='UnwrappedFace')
    cli.add_argument("--dataset", default='data/nemo', help="Path to dataset")
    cli.add_argument("--folder", default="out", help="out folder")
    # NOTE(review): --arch is parsed but never read below; the network is always
    # built with the fixed constructor arguments that follow.
    cli.add_argument("--arch", default='unet_64', help="Network architecture")
    cli.add_argument("--format", default='.gif', help="Save format")
    opts = cli.parse_args()

    generator = UnwrappedFaceWeightedAverage(
        output_num_channels=2,
        input_num_channels=3,
        inner_nc=512,
    ).cuda()
    eval_frames = FramesDataset(opts.dataset, is_train=False)

    checkpoint_path = os.path.join(opts.folder, 'model.cpk')
    reconstruction(
        generator,
        checkpoint=checkpoint_path,
        log_dir=opts.folder,
        dataset=eval_frames,
        format=opts.format,
    )