# Seeded inference for Table 1 metrics
GPU used: NVIDIA RTX 4000 Ada Generation (20GB)

In [None]:
from Nifty.method import *
from Nifty.networks import *
from Nifty.TO import TextureOptimization
import warnings; warnings.filterwarnings('ignore')

# Side length, in pixels, of the square synthesized images; it matches the 512
# that the synthesis cells of this notebook request from each method.
output_size = 512

## Nifty

In [None]:
# Benchmark Nifty: for each reference texture, synthesize `n_seeds` seeded
# samples, save them under ./comparison/synthesis4metrics/<image>/nifty/ and
# report the mean wall-clock time of one synthesis.
image_ids = range(1, 13)  # reference images are stored as 1.png ... 12.png
n_seeds = 10              # seeded samples generated per reference image


def _sync_gpu():
    """Wait for pending CUDA work so that timestamps bracket the real compute."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()


time_nifty = 0
for j in tqdm(image_ids):
    im_name = '%d.png' % j
    path = os.path.join('./comparison/eval_base/', im_name)
    img = Tensor_load(path)
    os.makedirs('./comparison/synthesis4metrics/%d/nifty' % j, exist_ok=True)
    for i in range(n_seeds):
        # The seed is the sample index, so every saved sample is reproducible.
        torch.manual_seed(i)
        _sync_gpu()
        tic = time.perf_counter()
        synth = Nifty(img, rs=1, T=15, k=5, patchsize=16, stride=4, octaves=4,
                      size=(output_size, output_size), renoise=0.5, warmup=0,
                      memory=True, show=False, save=False)
        # CUDA kernels are launched asynchronously: synchronize before reading
        # the clock, otherwise unfinished GPU work is left out of the timing.
        _sync_gpu()
        time_nifty += time.perf_counter() - tic
        # Saving is deliberately kept outside the timed region.
        imsave('./comparison/synthesis4metrics/%d/nifty/%d.png' % (j, i), synth)
# Average over every synthesis actually run instead of a hard-coded count.
print('Nifty time: %.2f s' % (time_nifty / (len(image_ids) * n_seeds)))

100%|██████████| 12/12 [01:32<00:00,  7.73s/it]

Nifty time: 0.70 s





## Texture Optimization

In [None]:
# Benchmark Texture Optimization (Kwatra): for each reference texture,
# synthesize `n_seeds` seeded samples, save them under
# ./comparison/synthesis4metrics/<image>/kwatra/ and report the mean wall-clock
# time of one synthesis.
image_ids = range(1, 13)  # reference images are stored as 1.png ... 12.png
n_seeds = 10              # seeded samples generated per reference image


def _sync_gpu():
    """Wait for pending CUDA work so that timestamps bracket the real compute."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()


time_kwatra = 0
for j in tqdm(image_ids):
    im_name = '%d.png' % j
    path = os.path.join('./comparison/eval_base/', im_name)
    # Affine remap x -> x/2 + 1/2 applied to the loaded image before synthesis;
    # the inverse map (x*2 - 1) is applied to the result before saving.
    # NOTE(review): presumably Tensor_load returns values in [-1, 1] and
    # TextureOptimization expects [0, 1] -- confirm against the Nifty package.
    img = Tensor_load(path) * .5 + .5
    os.makedirs('./comparison/synthesis4metrics/%d/kwatra' % j, exist_ok=True)

    for i in range(n_seeds):
        # The seed is the sample index, so every saved sample is reproducible.
        torch.manual_seed(i)
        _sync_gpu()
        tic = time.perf_counter()
        synth, loss = TextureOptimization(img, N_subsampling=10000000,
                                          output_size=output_size)
        # CUDA kernels are launched asynchronously: synchronize before reading
        # the clock, otherwise unfinished GPU work is left out of the timing.
        _sync_gpu()
        time_kwatra += time.perf_counter() - tic
        # Saving is deliberately kept outside the timed region.
        imsave('./comparison/synthesis4metrics/%d/kwatra/%d.png' % (j, i), synth * 2 - 1)
# Average over every synthesis actually run instead of a hard-coded count.
print('Kwatra time: %.2f s' % (time_kwatra / (len(image_ids) * n_seeds)))

100%|██████████| 12/12 [04:00<00:00, 20.01s/it]

Kwatra time: 1.92 s





## U-Net approximation of the flow

In [None]:
# Benchmark the U-Net approximation of the flow (FM).
# For each reference texture: load (or, if allowed, train) one U-Net, then draw
# `n_seeds` seeded samples by integrating the learned flow with T explicit Euler
# steps on [0, 1]. Samples are saved under
# ./comparison/synthesis4metrics/<image>/FM/ and the mean sampling time (plus the
# mean training time, when networks were trained) is printed.

# Safety switch: training is long, so missing checkpoints raise an error unless
# the user explicitly opts in to retraining here.
I_actually_want_to_retrain_networks_rather_than_unziping = True

image_ids = range(1, 13)  # reference images are stored as 1.png ... 12.png
n_seeds = 10              # seeded samples generated per reference image
T = 15                    # number of Euler integration steps

# Ids (as strings, ready to be joined for display) whose checkpoint is missing.
l_retrain = [str(j) for j in image_ids
             if not os.path.exists('comparison/models4metrics/%d.pth' % j)]
if l_retrain:
    print('Will train a network for images: %s' % ','.join(l_retrain))

if l_retrain and not I_actually_want_to_retrain_networks_rather_than_unziping:
    raise RuntimeError(
        "Networks are available in a zip file. Unzip them or set "
        "I_actually_want_to_retrain_networks_rather_than_unziping=True to retrain them."
    )

# Integration time grid; it does not depend on the image or the seed, so it is
# built once instead of once per sample.
times = torch.linspace(0, 1, steps=T + 1).cuda()

time_fm = 0
train_time_fm = 0
for j in tqdm(image_ids, disable=False):
    im_name = '%d.png' % j
    path = os.path.join('./comparison/eval_base/', im_name)
    img = Tensor_load(path)
    os.makedirs('./comparison/synthesis4metrics/%d/FM' % j, exist_ok=True)

    # The network works on the standardized image; samples are mapped back
    # with the same statistics after integration.
    mu, sigma = img.mean(), img.std()

    flow_model = UNet(
        dim=16,
        dim_mults=(1, 2, 4, 4))

    model_path = 'comparison/models4metrics/%d.pth' % j
    if not os.path.exists(model_path):
        # A missing checkpoint can only get here with retraining enabled: the
        # guard above already raised otherwise.
        os.makedirs('comparison/models4metrics', exist_ok=True)
        torch.manual_seed(0)
        tic = time.perf_counter()
        train_flow_net((img - mu) / sigma, flow_model, load=False, epochs=10000,
                       show=False, save_name=model_path)
        train_time_fm += time.perf_counter() - tic

    flow_model.load_state_dict(torch.load(model_path, map_location='cuda'), strict=False)
    flow_model.eval().cuda()

    for i in range(n_seeds):
        # The seed is the sample index, so every saved sample is reproducible.
        torch.manual_seed(i)
        # CUDA work is asynchronous: synchronize on both sides of the timed
        # region so that the clock brackets the actual GPU compute.
        torch.cuda.synchronize()
        tic = time.perf_counter()
        with torch.no_grad():
            # Noise is drawn on the CPU and then moved to the GPU; sampling
            # directly on the GPU would use a different generator and change
            # the seeded results.
            x = torch.randn(1, 3, output_size, output_size).cuda()
            for it in range(T):
                # Explicit Euler step: x <- x + v(x, t) * dt.
                flow = flow_model(x, times[it].view(1))
                x = x + flow * (times[it + 1] - times[it])
        synth_nn = x * sigma + mu
        torch.cuda.synchronize()
        time_fm += time.perf_counter() - tic
        # Saving is deliberately kept outside the timed region.
        imsave('./comparison/synthesis4metrics/%d/FM/%d.png' % (j, i), synth_nn)
if l_retrain:
    # Mean training time per trained network, split into minutes and seconds.
    train_min, train_sec = divmod(train_time_fm / len(l_retrain), 60)
    print('FM training time: %d min and %d s' % (train_min, int(train_sec)))
# Average over every sample actually drawn instead of a hard-coded count.
print('FM time: %.2f s' % (time_fm / (len(image_ids) * n_seeds)))


Will train a network for images: 1,2,3,4,5,6,7,8,9,10,11,12


  0%|          | 0/12 [00:02<?, ?it/s]


KeyboardInterrupt: 