In [1]:
import os
import sys
import time
import argparse
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms

import config
import myutils
from myutils import test_images_3d, test_images_2d, test_images_outp
from loss import Loss
from torch.utils.data import DataLoader
from pdb import set_trace as bp
import cv2

import models
from utils import make_coord
from torch.autograd import Variable

# from dataset.Davis_test import get_loader

def make_coord_3d(shape, time, ranges=None, flatten=True):
    """Make coordinates at grid centers, tagged with a time value.

    For every axis ``i`` of ``shape`` the interval ``ranges[i]`` (default
    ``(-1, 1)``) is divided into ``shape[i]`` equal cells and the centers of
    those cells are used as coordinates.

    Args:
        shape: Iterable with the number of cells along each axis.
        time: Scalar written into an extra last column of every row when
            ``flatten`` is true.
        ranges: Optional per-axis ``(low, high)`` pairs. ``None`` means
            ``(-1, 1)`` for every axis.
        flatten: If true, return one row per grid point laid out as
            ``(*coords, time)``. If false, return the unflattened grid of
            shape ``(*shape, len(shape))``; ``time`` is NOT appended in that
            case.

    Returns:
        A float tensor with the grid-center coordinates.
    """
    coord_seqs = []
    for i, n in enumerate(shape):
        if ranges is None:
            v0, v1 = -1, 1
        else:
            v0, v1 = ranges[i]
        # r is half a cell; centers sit at v0 + r, v0 + 3r, ...
        r = (v1 - v0) / (2 * n)
        seq = v0 + r + (2 * r) * torch.arange(n).float()
        coord_seqs.append(seq)
    ret = torch.stack(torch.meshgrid(*coord_seqs), dim=-1)
    if flatten:
        ret = ret.view(-1, ret.shape[-1])
        # Build the time column in the grid's own dtype so that an integer
        # ``time`` (e.g. 0 or 1) cannot produce an int64 column and make the
        # concatenation depend on dtype promotion.
        bias = torch.full((ret.shape[0], 1), float(time), dtype=ret.dtype)
        ret = torch.cat([ret, bias], dim=1)
    return ret


def get_name(index):
    """Return the zero-padded JPEG file name for frame ``index``.

    The number is padded with zeros to at least five digits, e.g.
    ``7 -> '00007.jpg'`` and ``1234 -> '01234.jpg'``. Indices of 10000 and
    above keep their natural width (``'10000.jpg'``) instead of receiving an
    extra leading zero, so the names keep sorting in frame order.

    Args:
        index: Non-negative integer frame number.

    Returns:
        The file name as a string ending in ``.jpg``.
    """
    return '{:05d}.jpg'.format(index)


def save_image(img, index):
    """Save ``img`` under its frame name in the directory for ``index``.

    Frames are spread over the module-level directories ``out_path1`` ..
    ``out_path7`` in buckets of 1600 indices: ``index < 1600`` goes to
    ``out_path1``, ``index < 3200`` to ``out_path2`` and so on; everything
    from 9600 upwards goes to ``out_path7``. The file name comes from
    ``get_name(index)``.
    """
    # Pick the destination first, then do the single save call.
    if index < 1600:
        target_dir = out_path1
    elif index < 3200:
        target_dir = out_path2
    elif index < 4800:
        target_dir = out_path3
    elif index < 6400:
        target_dir = out_path4
    elif index < 8000:
        target_dir = out_path5
    elif index < 9600:
        target_dir = out_path6
    else:
        target_dir = out_path7
    img.save(os.path.join(target_dir, get_name(index)))
        
# Inference cell: slide a 4-frame window over an image sequence, save the
# (resized) second frame of each window, run the 'liif_optical' model on the
# window and save its output frame; once enough frames exist, assemble the
# first output folder into an AVI.

# One id per visible CUDA device for DataParallel; `device` falls back to the
# CPU device object when CUDA is unavailable.
device_ids = [Id for Id in range(torch.cuda.device_count())]
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Build the model from its spec, wrap it in DataParallel and load the
# checkpoint named by `name`.
model_args = {'encoder_spec': {'name': 'edsr-baseline', 'args': {'no_upsampling': True}}, 'imnet_spec': {'name': 'mlp', 'args': {'out_dim': 3, 'hidden_list': [64, 64]}}}
model_spec = {'name': 'liif_optical', 'args': model_args}
model = models.make(model_spec).to(device)
model = nn.DataParallel(model, device_ids=device_ids)
name = 'temp-109-25-499.pth'
model.load_state_dict(torch.load('/model/nnice1216/video/' + name))

# Alternative input sequence:
# imgpath = '/data/nnice1216/vimeo_septuplet/DAVIS/JPEGImages/Full-Resolution/insects-original/'
imgpath = '/data/nnice1216/vimeo_septuplet/DAVIS/JPEGImages/Full-Resolution/insects-sampled/'
model.eval()
epoch = 0
# time = 0.5
iter_id = 1
# `signal` flips to 1 after the video has been written so that step runs once.
signal = 0
with torch.no_grad():
    name_list = sorted(os.listdir(imgpath))
    # name_list = sorted(name_list)[20:140:2]
    # Running frame number used for output file names and bucket selection.
    index = 0
    # Output directories read by save_image() as module-level globals.
    out_path1 = '/output/Imagestry11/'
    out_path2 = '/output/Imagestry22/'
    out_path3 = '/output/Imagestry33/'
    out_path4 = '/output/Imagestry44/'
    out_path5 = '/output/Imagestry55/'
    out_path6 = '/output/Imagestry66/'
    out_path7 = '/output/Imagestry77/'
    if not os.path.exists(out_path1):
        os.makedirs(out_path1)
    if not os.path.exists(out_path2):
        os.makedirs(out_path2)
    if not os.path.exists(out_path3):
        os.makedirs(out_path3)
    if not os.path.exists(out_path4):
        os.makedirs(out_path4)
    if not os.path.exists(out_path5):
        os.makedirs(out_path5)
    if not os.path.exists(out_path6):
        os.makedirs(out_path6)
    if not os.path.exists(out_path7):
        os.makedirs(out_path7)


    # Each iteration handles the window name_list[image_id : image_id + 4].
    for image_id in range(len(name_list) - 3):

        start = time.time()
        imgpaths = [name_list[i] for i in range(image_id, image_id + 4)]
        pth_ = imgpaths

        images = [Image.open(os.path.join(imgpath, pth)) for pth in imgpaths]
        # NOTE(review): PIL's `size` is (width, height), so these names are
        # swapped; both are overwritten from the tensor shape before use.
        h, w = images[0].size
        # print(h, w)
        # PIL resize takes (width, height): frames become 960 wide, 544 high.
        images = [img.resize((960, 544)) for img in images]
        # Save the second frame of the window as an output frame.
        save_image(images[1], index)
        print("INDEX {}, DONE!".format(index))
        # images[1].save(os.path.join(out_path, get_name(index)))
        index += 1
        print("START: ", image_id, index)

        T = transforms.ToTensor()
        # images = [((T(img_) - 0.5) * 2)[None] for img_ in images]
        # Add a leading batch axis to every frame tensor.
        images = [T(img_)[None] for img_ in images]
        h, w = images[0].shape[2], images[0].shape[3]
        print(h, w)
        # One coordinate grid per time step k/8, k = 0..7.
        coord0 = make_coord_3d((h, w), 0.0)
        coord1 = make_coord_3d((h, w), 1 / 8)
        coord2 = make_coord_3d((h, w), 2 / 8)
        coord3 = make_coord_3d((h, w), 3 / 8)
        coord4 = make_coord_3d((h, w), 4 / 8)
        coord5 = make_coord_3d((h, w), 5 / 8)
        coord6 = make_coord_3d((h, w), 6 / 8)
        coord7 = make_coord_3d((h, w), 7 / 8)
        coords = [coord0, coord1, coord2, coord3, coord4, coord5, coord6, coord7]
        # Cell tensors shaped like the coordinates: columns 0/1 become 2/h and
        # 2/w (one grid cell in the [-1, 1] range), column 2 becomes k/8.
        # NOTE(review): cell8 is filled in but never added to `cells`.
        cell0, cell1, cell2, cell3, cell4, cell5, cell6, cell7, cell8 = torch.ones_like(coord1), torch.ones_like(coord1), torch.ones_like(coord1), torch.ones_like(coord1), torch.ones_like(coord1), torch.ones_like(coord1), torch.ones_like(coord1), torch.ones_like(coord1), torch.ones_like(coord1)
        cell0[:, 0] *= 2 / h
        cell0[:, 1] *= 2 / w
        cell1[:, 0] *= 2 / h
        cell1[:, 1] *= 2 / w
        cell2[:, 0] *= 2 / h
        cell2[:, 1] *= 2 / w
        cell3[:, 0] *= 2 / h
        cell3[:, 1] *= 2 / w
        cell4[:, 0] *= 2 / h
        cell4[:, 1] *= 2 / w
        cell5[:, 0] *= 2 / h
        cell5[:, 1] *= 2 / w
        cell6[:, 0] *= 2 / h
        cell6[:, 1] *= 2 / w
        cell7[:, 0] *= 2 / h
        cell7[:, 1] *= 2 / w
        cell8[:, 0] *= 2 / h
        cell8[:, 1] *= 2 / w

        cell0[:, 2] *= 0.
        cell1[:, 2] *= 1 / 8
        cell2[:, 2] *= 2 / 8
        cell3[:, 2] *= 3 / 8
        cell4[:, 2] *= 4 / 8
        cell5[:, 2] *= 5 / 8
        cell6[:, 2] *= 6 / 8
        cell7[:, 2] *= 7 / 8
        cell8[:, 2] *= 1.
        cells = [cell0, cell1, cell2, cell3, cell4, cell5, cell6, cell7]

        # Add a batch axis and move everything to the compute device.
        coords = [c_[None].to(device) for c_ in coords]
        # NOTE(review): `cells` is moved to the device but is not passed to
        # the model call below.
        cells = [ce_[None].to(device) for ce_ in cells]

        images = [img_.to(device) for img_ in images]
        # Stack the four frames along a new axis 2.
        images = torch.stack(images, dim=2)

        # NOTE(review): called unconditionally although `device` may be the
        # CPU; presumably here so the timing below covers the forward pass.
        torch.cuda.synchronize()

        outs = model(images, coords, 8, index)

        torch.cuda.synchronize()
        # Jump 62 frame numbers before saving the model output; with the two
        # `index += 1` steps each iteration advances `index` by 64.
        # NOTE(review): the reason for the gap is not visible in this cell.
        index += 62
        # for out in outs:
        # Clamp to [0, 1], move channels last and take the first batch item;
        # assumes `outs` is (batch, channels, height, width) -- TODO confirm.
        out_temp = (outs).clamp(0, 1).permute(0, 2, 3, 1)[0].cpu()
        save_image(Image.fromarray((out_temp.numpy() * 255).astype(np.uint8)), index)
        print("INDEX {}, DONE!".format(index))
        index += 1
        end = time.time()
        print("Epoch {} End, Index {}, Cost time {}".format(image_id, index, end - start))

        # Drop the output tensors and release cached GPU memory between windows.
        del outs, out_temp
        torch.cuda.empty_cache()

        # Once more than 1600 frame numbers have been used (the first output
        # folder's range), write that folder to an MJPG AVI exactly once.
        if index > 1600 and signal == 0:
            fps = 60
            size = (960, 544)
            path = '/output/Imagestry11'
            video = cv2.VideoWriter("/output/Video2234.avi", cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), fps, size)

            # NOTE(review): this loop variable overwrites the checkpoint
            # file name stored in the global `name`.
            for i, name in enumerate(sorted(os.listdir(path))):
                if i % 50 == 0:
                    print(name)
                img = cv2.imread(os.path.join(path, name))
                video.write(img)

            video.release()
            signal = 1

BATCHNORM:  False
INDEX 0, DONE!
START:  0 1
544 960
---Testing--- 5
INDEX 1, DONE!
---Testing--- 5
INDEX 2, DONE!
---Testing--- 5
INDEX 3, DONE!
---Testing--- 5
INDEX 4, DONE!
---Testing--- 7
INDEX 5, DONE!
---Testing--- 9
INDEX 6, DONE!
---Testing--- 11
INDEX 7, DONE!
---Testing--- 11
INDEX 8, DONE!
---Testing--- 13
INDEX 9, DONE!
---Testing--- 15
INDEX 10, DONE!
---Testing--- 15
INDEX 11, DONE!
---Testing--- 17
INDEX 12, DONE!
---Testing--- 19
INDEX 13, DONE!
---Testing--- 21
INDEX 14, DONE!
---Testing--- 21
INDEX 15, DONE!
---Testing--- 23
INDEX 16, DONE!
---Testing--- 25
INDEX 17, DONE!
---Testing--- 25
INDEX 18, DONE!
---Testing--- 27
INDEX 19, DONE!
---Testing--- 29
INDEX 20, DONE!
---Testing--- 31
INDEX 21, DONE!
---Testing--- 31
INDEX 22, DONE!
---Testing--- 33
INDEX 23, DONE!
---Testing--- 35
INDEX 24, DONE!
---Testing--- 37
INDEX 25, DONE!
---Testing--- 37
INDEX 26, DONE!
---Testing--- 39
INDEX 27, DONE!
---Testing--- 41
INDEX 28, DONE!
---Testing--- 41
INDEX 29, DONE!
---Te

KeyboardInterrupt: 

In [None]:
import torchvision.transforms as tfs
from PIL import Image
import torch
from model.FLAVR_arch import UNet_3D_3D
import config
import numpy as np
import models
from utils import make_coord
import torch.nn as nn


def make_coord_2d(shape, ranges=None, flatten=True):
    """Return the cell-center coordinates of a regular grid.

    Each axis ``i`` spans ``ranges[i]`` (``(-1, 1)`` when ``ranges`` is
    ``None``) and is divided into ``shape[i]`` equal cells. With ``flatten``
    the result has one row per grid point; otherwise the grid keeps the
    layout ``(*shape, len(shape))``.
    """
    axes = []
    for axis, count in enumerate(shape):
        low, high = (-1, 1) if ranges is None else ranges[axis]
        # Half the width of one cell along this axis.
        half_step = (high - low) / (2 * count)
        centers = low + half_step + (2 * half_step) * torch.arange(count).float()
        axes.append(centers)
    grid = torch.stack(torch.meshgrid(*axes), dim=-1)
    if not flatten:
        return grid
    return grid.view(-1, grid.shape[-1])



'''
args, unparsed = config.get_args()
device_ids = [Id for Id in range(torch.cuda.device_count())]
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Building model: %s"%args.model.lower())
# model = UNet_3D_3D(args.model.lower() , n_inputs=args.nbr_frame, n_outputs=args.n_outputs, joinType=args.joinType).to(device)
model = UNet_3D_3D('unet_18', n_inputs=4, n_outputs=1, joinType='concat').to(device)
print(args.model.lower(), args.nbr_frame, args.n_outputs, args.joinType)
model = torch.nn.DataParallel(model, device_ids=device_ids)
model_dict = model.state_dict()
model.load_state_dict(torch.load(args.load_from)["state_dict"] , strict=True)
print(args.load_from)
'''
# Setup cell: build the 'liif3d_flavr2' model, load four frames of one
# sequence as tensors and prepare the 2-D coordinate / cell grids that match
# the frame resolution.

# One id per visible CUDA device for DataParallel; `device` falls back to the
# CPU device object when CUDA is unavailable.
device_ids = [Id for Id in range(torch.cuda.device_count())]
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model_args = {'encoder_spec': {'name': 'edsr-baseline', 'args': {'no_upsampling': True}}, 'imnet_spec': {'name': 'mlp', 'args': {'out_dim': 3, 'hidden_list': [256, 256, 256, 256]}}}
model_spec = {'name': 'liif3d_flavr2', 'args': model_args}
model = models.make(model_spec).to(device)
model = nn.DataParallel(model, device_ids=device_ids)
# No checkpoint is loaded in this cell; enable one of these to restore weights.
# model.load_state_dict(torch.load('/output/vimeo_iter1999.pth'))
# model.load_state_dict(torch.load('/code/FLAVR-main/vimeo_iter2999.pth'))

imgpath = '/data/nnice1216/vimeo_septuplet/DAVIS/JPEGImages/Full-Resolution/dolphins-show/'

#imgpaths = [imgpath + f'/im{i}.png' for i in range(1,8)]
imgpaths = [imgpath + f'/0000{i}.jpg' for i in range(1,8)]
pth_ = imgpaths

images = [Image.open(pth) for pth in imgpaths]
# PIL reports size as (width, height); the original unpacked it as (h, w).
w, h = images[0].size
# PIL resize takes (width, height): frames become 360 wide, 208 high.
images = [img.resize((360, 208)) for img in images]
# Zero-based positions of the frames to keep (frames 1-4 of the 7 listed).
inputs = [int(e)-1 for e in list('1234')]
# Re-joining both halves leaves the list unchanged; kept as in the original.
inputs = inputs[:len(inputs)//2] + inputs[len(inputs)//2:]
images = [images[i] for i in inputs]
imgpaths = [imgpaths[i] for i in inputs]

T = tfs.ToTensor()
# Add a leading batch axis to every frame tensor.
images = [T(img_)[None] for img_ in images]

# The middle frame doubles as the reference ("gt") frame.
gt = images[len(images)//2]
images = images[:len(images)//2] + images[len(images)//2:]
images, gt_image = images, [gt]
# `gt` is (1, C, H, W). The original read `gt[0].shape[1], gt.shape[2]`,
# which is the height axis twice, so `w` ended up equal to the height and the
# grid below was built as H x H instead of H x W.
h, w = gt.shape[2], gt.shape[3]
coord = make_coord_2d((h, w))
# Per-point cell size: one grid cell in the [-1, 1] coordinate range.
cell = torch.ones_like(coord)
cell[:, 0] *= 2 / h
cell[:, 1] *= 2 / w

In [2]:
from model.FLAVR_arch import UNet_3D_3D

import os
import sys
import time
import argparse
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms

import config
import myutils
from myutils import test_images_3d, test_images_2d, test_images_outp
from loss import Loss
from torch.utils.data import DataLoader
from pdb import set_trace as bp

import models
from utils import make_coord
from torch.autograd import Variable

def get_name(index):
    """Return the zero-padded JPEG file name for frame ``index``.

    The number is padded with zeros to at least five digits, e.g.
    ``7 -> '00007.jpg'`` and ``1234 -> '01234.jpg'``. Indices of 10000 and
    above keep their natural width (``'10000.jpg'``) instead of receiving an
    extra leading zero, so the names keep sorting in frame order.

    Args:
        index: Non-negative integer frame number.

    Returns:
        The file name as a string ending in ``.jpg``.
    """
    return '{:05d}.jpg'.format(index)

def save_image(img, index):
    """Save ``img`` under its frame name in the directory for ``index``.

    Frames are spread over the module-level directories ``out_path1`` ..
    ``out_path7`` in buckets of 1600 indices: ``index < 1600`` goes to
    ``out_path1``, ``index < 3200`` to ``out_path2`` and so on; everything
    from 9600 upwards goes to ``out_path7``. The file name comes from
    ``get_name(index)``.
    """
    # Pick the destination first, then do the single save call.
    if index < 1600:
        target_dir = out_path1
    elif index < 3200:
        target_dir = out_path2
    elif index < 4800:
        target_dir = out_path3
    elif index < 6400:
        target_dir = out_path4
    elif index < 8000:
        target_dir = out_path5
    elif index < 9600:
        target_dir = out_path6
    else:
        target_dir = out_path7
    img.save(os.path.join(target_dir, get_name(index)))


# Inference cell for the FLAVR UNet_3D_3D model: slide a 4-frame window over
# an image sequence, save the second frame of each window, then save every
# frame the model returns for that window.

# Model name, frame counts, join type and checkpoint path all come from the
# project's config module.
args, unparsed = config.get_args()

# One id per visible CUDA device; `device` falls back to the CPU device object
# when CUDA is unavailable. `device_ids` is not used further in this cell.
device_ids = [Id for Id in range(torch.cuda.device_count())]
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

print("Building model: %s"%args.model.lower())
model = UNet_3D_3D(args.model.lower() , n_inputs=args.nbr_frame, n_outputs=args.n_outputs, joinType=args.joinType)
model = torch.nn.DataParallel(model).to(device)
# NOTE(review): `model_dict` is not used afterwards in this cell.
model_dict = model.state_dict()
# The checkpoint stores the weights under the "state_dict" key.
model.load_state_dict(torch.load(args.load_from)["state_dict"] , strict=True)
# NOTE(review): model.eval() is not called in this cell -- confirm whether
# inference in training mode is intended.

with torch.no_grad():
    # Output directories read by save_image() as module-level globals.
    out_path1 = '/output/Imagestry1/'
    out_path2 = '/output/Imagestry2/'
    out_path3 = '/output/Imagestry3/'
    out_path4 = '/output/Imagestry4/'
    out_path5 = '/output/Imagestry5/'
    out_path6 = '/output/Imagestry6/'
    out_path7 = '/output/Imagestry7/'
    if not os.path.exists(out_path1):
        os.makedirs(out_path1)
    if not os.path.exists(out_path2):
        os.makedirs(out_path2)
    if not os.path.exists(out_path3):
        os.makedirs(out_path3)
    if not os.path.exists(out_path4):
        os.makedirs(out_path4)
    if not os.path.exists(out_path5):
        os.makedirs(out_path5)
    if not os.path.exists(out_path6):
        os.makedirs(out_path6)
    if not os.path.exists(out_path7):
        os.makedirs(out_path7)

    # imgpath = '/data/nnice1216/vimeo_septuplet/DAVIS/JPEGImages/Full-Resolution/dolphins-show/'
    imgpath = '/data/nnice1216/vimeo_septuplet/DAVIS/highFPS/dolphins-flavr-8x/'
    name_list = sorted(os.listdir(imgpath))
    # name_list = sorted(name_list)[20:140:2]
    # Running frame number used for output file names and bucket selection.
    index = 0
    # Each iteration handles the window name_list[image_id : image_id + 4].
    for image_id in range(len(name_list) - 3):
        # NOTE(review): this directory is (re)checked on every iteration and
        # nothing is written to it here; its only use is commented out below.
        out_path = '/output/lowres-dolphins-8x/'
        if not os.path.exists(out_path):
            os.makedirs(out_path)

        imgpaths = [name_list[i] for i in range(image_id, image_id + 4)]
        pth_ = imgpaths

        images = [Image.open(os.path.join(imgpath, pth)) for pth in imgpaths]
        # NOTE(review): PIL's `size` is (width, height), so these names are
        # swapped; both are overwritten from the tensor shape below.
        h, w = images[0].size
        # print(h, w)
        # images = [img.resize((960, 544)) for img in images]
        # Save the second frame of the window as an output frame.
        save_image(images[1], index)
        # images[1].save(os.path.join(out_path, get_name(index)))
        index += 1
        print("START: ", image_id, index)

        T = transforms.ToTensor()
        # images = [((T(img_) - 0.5) * 2)[None] for img_ in images]
        # Add a leading batch axis to every frame tensor.
        images = [T(img_)[None] for img_ in images]
        h, w = images[0].shape[2], images[0].shape[3]

        images = [img_.to(device) for img_ in images]

        # NOTE(review): called unconditionally although `device` may be the CPU.
        torch.cuda.synchronize()

        # The model receives the list of four frame tensors.
        out = model(images)

        torch.cuda.synchronize()
        # Save every returned frame: clamp to [0, 1], take the first batch
        # item and move channels last; assumes each out[i] is
        # (batch, channels, height, width) -- TODO confirm.
        for i in range(len(out)):
            out_temp = (out[i]).clamp(0, 1)[0].permute(1, 2, 0).cpu()
            save_image(Image.fromarray((out_temp.numpy() * 255).astype(np.uint8)), index)
            index += 1
            print(index)

Unparsed args: ['-f', '/root/.local/share/jupyter/runtime/kernel-11746859-5a0a-4b93-be9d-7f02753598f1.json']
Building model: unet_18
START:  0 1
2
3
4
5
6
7
8
START:  1 9
10
11
12
13
14
15
16
START:  2 17
18
19
20
21
22
23
24
START:  3 25
26
27
28
29
30
31
32
START:  4 33
34
35
36
37
38
39
40
START:  5 41
42
43
44
45
46
47
48
START:  6 49
50
51
52
53
54
55
56
START:  7 57
58
59
60
61
62
63
64
START:  8 65
66
67
68
69
70
71
72
START:  9 73
74
75
76
77
78
79
80
START:  10 81
82
83
84
85
86
87
88
START:  11 89
90
91
92
93
94
95
96
START:  12 97
98
99
100
101
102
103
104
START:  13 105
106
107
108
109
110
111
112
START:  14 113
114
115
116
117
118
119
120
START:  15 121
122
123
124
125
126
127
128
START:  16 129
130
131
132
133
134
135
136
START:  17 137
138
139
140
141
142
143
144
START:  18 145
146
147
148
149
150
151
152
START:  19 153
154
155
156
157
158
159
160
START:  20 161
162
163
164
165
166
167
168
START:  21 169
170
171
172
173
174
175
176
START:  22 177
178
179
180
181
182
183


In [14]:
def get_name(index):
    """Return ``index`` as a zero-padded frame number without extension.

    The number is padded with zeros to at least five digits, e.g.
    ``10 -> '00010'``. Indices of 10000 and above keep their natural width
    (``'10000'``) instead of receiving an extra leading zero, so the names
    keep sorting in frame order.

    Args:
        index: Non-negative integer frame number.

    Returns:
        The padded number as a string.
    """
    return '{:05d}'.format(index)
get_name(10)

'00010'

In [3]:
from PIL import Image
import os

# Shrink every image of the dolphins-show sequence to a quarter of its width
# and height (bicubic) and save it under the same file name in a sibling
# "-lowres4x" directory.
path = '/data/nnice1216/vimeo_septuplet/DAVIS/JPEGImages/Full-Resolution/dolphins-show/'
out_path2 = '/data/nnice1216/vimeo_septuplet/DAVIS/JPEGImages/Full-Resolution/dolphins-show-lowres4x/'
if not os.path.exists(out_path2):
    os.makedirs(out_path2)

for i, name in enumerate(sorted(os.listdir(path))):
    # Progress indicator: print every 50th file name.
    if i % 50 == 0:
        print(name)
    img = Image.open(os.path.join(path, name))
    # PIL's `size` is (width, height), so `h` holds the width and `w` the
    # height here; resize() also takes (width, height), so the call below
    # still scales both sides correctly.
    h, w = img.size
    img = img.resize((h // 4, w // 4), resample=Image.BICUBIC)
    img.save(os.path.join(out_path2, name))

# NOTE(review): no `video` object is created in this cell; this only works if
# another cell defined one earlier in the same kernel session -- confirm.
video.release()

00000.jpg
00050.jpg


In [8]:
import cv2
import numpy as np
from PIL import Image
import os

# Resize every image in /output/Imagestry3/ to 960x544 (bicubic) and save it
# under the same file name in /output/all_images1/.
fps = 60
size = (1280, 720)
path1 = '/output/Imagestry1/'
path2 = '/output/Imagestry2/'
path3 = '/output/Imagestry3/'
out_path2 = '/output/all_images1/'
if not os.path.exists(out_path2):
    os.makedirs(out_path2)
# NOTE(review): this combined listing is not used below; only path3 is processed.
name_list = sorted(os.listdir(path1)) + sorted(os.listdir(path2)) + sorted(os.listdir(path3))
# NOTE(review): the writer is opened but no frame is written to it and the
# release call at the end of the cell is commented out.
video = cv2.VideoWriter("/output/Video3.avi", cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), fps, size)

for i, name in enumerate(sorted(os.listdir(path3))):
    img = Image.open(os.path.join(path3, name))
    # PIL resize takes (width, height).
    img = img.resize((960, 544), resample=Image.BICUBIC)
    img.save(os.path.join(out_path2, name))
    # Progress indicator: print every 50th file name.
    if i % 50 == 0:
        print(name)

# video.release()

03200.jpg
03250.jpg
03300.jpg
03350.jpg
03400.jpg
03450.jpg
03500.jpg
03550.jpg
03600.jpg
03650.jpg
03700.jpg
03750.jpg
03800.jpg
03850.jpg
03900.jpg
03950.jpg
04000.jpg
04050.jpg
04100.jpg
04150.jpg
04200.jpg
04250.jpg
04300.jpg
04350.jpg
04400.jpg
04450.jpg
04500.jpg


In [4]:
# List the files of the bmi-rider 8x sequence; the notebook displays the
# sorted list because it is the last expression of the cell. `name_list`
# itself stays in directory order.
imgpath = '/data/nnice1216/vimeo_septuplet/DAVIS/highFPS/bmi-rider-flavr-8x/'
name_list = os.listdir(imgpath)
sorted(name_list)

['00000.jpg',
 '00001.jpg',
 '00002.jpg',
 '00003.jpg',
 '00004.jpg',
 '00005.jpg',
 '00006.jpg',
 '00007.jpg',
 '00008.jpg',
 '00009.jpg',
 '00010.jpg',
 '00011.jpg',
 '00012.jpg',
 '00013.jpg',
 '00014.jpg',
 '00015.jpg',
 '00016.jpg',
 '00017.jpg',
 '00018.jpg',
 '00019.jpg',
 '00020.jpg',
 '00021.jpg',
 '00022.jpg',
 '00023.jpg',
 '00024.jpg',
 '00025.jpg',
 '00026.jpg',
 '00027.jpg',
 '00028.jpg',
 '00029.jpg',
 '00030.jpg',
 '00031.jpg',
 '00032.jpg',
 '00033.jpg',
 '00034.jpg',
 '00035.jpg',
 '00036.jpg',
 '00037.jpg',
 '00038.jpg',
 '00039.jpg',
 '00040.jpg',
 '00041.jpg',
 '00042.jpg',
 '00043.jpg',
 '00044.jpg',
 '00045.jpg',
 '00046.jpg',
 '00047.jpg',
 '00048.jpg',
 '00049.jpg',
 '00050.jpg',
 '00051.jpg',
 '00052.jpg',
 '00053.jpg',
 '00054.jpg',
 '00055.jpg',
 '00056.jpg',
 '00057.jpg',
 '00058.jpg',
 '00059.jpg',
 '00060.jpg',
 '00061.jpg',
 '00062.jpg',
 '00063.jpg',
 '00064.jpg',
 '00065.jpg',
 '00066.jpg',
 '00067.jpg',
 '00068.jpg',
 '00069.jpg',
 '00070.jpg',
 '0007

In [2]:
import time

In [3]:
# Capture the current local time as a struct_time.
t = time.localtime()

In [2]:
import os
import sys
import time
import argparse
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms

import config
import myutils
from myutils import test_images_3d, test_images_2d, test_images_outp
from loss import Loss
from torch.utils.data import DataLoader
from pdb import set_trace as bp

import models
from utils import make_coord
from torch.autograd import Variable

# from dataset.Davis_test import get_loader

def make_coord_3d(shape, time, ranges=None, flatten=True):
    """Make coordinates at grid centers, tagged with a time value.

    For every axis ``i`` of ``shape`` the interval ``ranges[i]`` (default
    ``(-1, 1)``) is divided into ``shape[i]`` equal cells and the centers of
    those cells are used as coordinates.

    Args:
        shape: Iterable with the number of cells along each axis.
        time: Scalar written into an extra last column of every row when
            ``flatten`` is true.
        ranges: Optional per-axis ``(low, high)`` pairs. ``None`` means
            ``(-1, 1)`` for every axis.
        flatten: If true, return one row per grid point laid out as
            ``(*coords, time)``. If false, return the unflattened grid of
            shape ``(*shape, len(shape))``; ``time`` is NOT appended in that
            case.

    Returns:
        A float tensor with the grid-center coordinates.
    """
    coord_seqs = []
    for i, n in enumerate(shape):
        if ranges is None:
            v0, v1 = -1, 1
        else:
            v0, v1 = ranges[i]
        # r is half a cell; centers sit at v0 + r, v0 + 3r, ...
        r = (v1 - v0) / (2 * n)
        seq = v0 + r + (2 * r) * torch.arange(n).float()
        coord_seqs.append(seq)
    ret = torch.stack(torch.meshgrid(*coord_seqs), dim=-1)
    if flatten:
        ret = ret.view(-1, ret.shape[-1])
        # Build the time column in the grid's own dtype so that an integer
        # ``time`` (e.g. 0 or 1) cannot produce an int64 column and make the
        # concatenation depend on dtype promotion.
        bias = torch.full((ret.shape[0], 1), float(time), dtype=ret.dtype)
        ret = torch.cat([ret, bias], dim=1)
    return ret


def get_name(index):
    """Return the zero-padded JPEG file name for frame ``index``.

    The number is padded with zeros to at least five digits, e.g.
    ``7 -> '00007.jpg'`` and ``1234 -> '01234.jpg'``. Indices of 10000 and
    above keep their natural width (``'10000.jpg'``) instead of receiving an
    extra leading zero, so the names keep sorting in frame order.

    Args:
        index: Non-negative integer frame number.

    Returns:
        The file name as a string ending in ``.jpg``.
    """
    return '{:05d}.jpg'.format(index)


def save_image(img, index):
    """Save ``img`` under its frame name in the directory for ``index``.

    Frames are spread over the module-level directories ``out_path1`` ..
    ``out_path7`` in buckets of 1600 indices: ``index < 1600`` goes to
    ``out_path1``, ``index < 3200`` to ``out_path2`` and so on; everything
    from 9600 upwards goes to ``out_path7``. The file name comes from
    ``get_name(index)``.
    """
    # Pick the destination first, then do the single save call.
    if index < 1600:
        target_dir = out_path1
    elif index < 3200:
        target_dir = out_path2
    elif index < 4800:
        target_dir = out_path3
    elif index < 6400:
        target_dir = out_path4
    elif index < 8000:
        target_dir = out_path5
    elif index < 9600:
        target_dir = out_path6
    else:
        target_dir = out_path7
    img.save(os.path.join(target_dir, get_name(index)))



# Frame-export cell: load the 'liif_bidi' model and, for a 4-frame window
# sliding over a subsampled sequence, save the resized second frame of every
# window. The model is loaded and put in eval mode but is not called here.

# One id per visible CUDA device for DataParallel; `device` falls back to the
# CPU device object when CUDA is unavailable.
device_ids = [Id for Id in range(torch.cuda.device_count())]
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Build the model from its spec, wrap it in DataParallel and load the
# checkpoint named by `name`.
model_args = {'encoder_spec': {'name': 'edsr-baseline', 'args': {'no_upsampling': True}}, 'imnet_spec': {'name': 'mlp', 'args': {'out_dim': 3, 'hidden_list': [64, 64]}}}
model_spec = {'name': 'liif_bidi', 'args': model_args}
model = models.make(model_spec).to(device)
model = nn.DataParallel(model, device_ids=device_ids)
name = '633-best-21-22.pth'
model.load_state_dict(torch.load('/model/nnice1216/video/' + name))

imgpath = '/data/nnice1216/vimeo_septuplet/DAVIS/JPEGImages/Full-Resolution/insects-original/'
model.eval()
epoch = 0
# time = 0.5
iter_id = 1
with torch.no_grad():
    name_list = os.listdir(imgpath)
    # Keep every second file from position 20 up to (not including) 140 of
    # the sorted listing.
    name_list = sorted(name_list)[20:140:2]
    # Running frame number used for output file names and bucket selection.
    index = 0
    # Output directories read by save_image() as module-level globals.
    out_path1 = '/output/Imagestry111/'
    out_path2 = '/output/Imagestry2/'
    out_path3 = '/output/Imagestry3/'
    out_path4 = '/output/Imagestry4/'
    out_path5 = '/output/Imagestry5/'
    out_path6 = '/output/Imagestry6/'
    out_path7 = '/output/Imagestry7/'
    if not os.path.exists(out_path1):
        os.makedirs(out_path1)
    if not os.path.exists(out_path2):
        os.makedirs(out_path2)
    if not os.path.exists(out_path3):
        os.makedirs(out_path3)
    if not os.path.exists(out_path4):
        os.makedirs(out_path4)
    if not os.path.exists(out_path5):
        os.makedirs(out_path5)
    if not os.path.exists(out_path6):
        os.makedirs(out_path6)
    if not os.path.exists(out_path7):
        os.makedirs(out_path7)


    # Each iteration handles the window name_list[image_id : image_id + 4].
    for image_id in range(len(name_list) - 3):

        # NOTE(review): `start` is not read again in this cell.
        start = time.time()
        imgpaths = [name_list[i] for i in range(image_id, image_id + 4)]
        pth_ = imgpaths

        images = [Image.open(os.path.join(imgpath, pth)) for pth in imgpaths]
        # NOTE(review): PIL's `size` is (width, height), so these names are
        # swapped; neither value is used afterwards in this cell.
        h, w = images[0].size
        # print(h, w)
        # PIL resize takes (width, height): frames become 960 wide, 544 high.
        images = [img.resize((960, 544)) for img in images]
        # Save the second frame of the window as an output frame.
        save_image(images[1], index)
        print("INDEX {}, DONE!".format(index))
        # images[1].save(os.path.join(out_path, get_name(index)))
        index += 1

BATCHNORM:  False
INDEX 0, DONE!
INDEX 1, DONE!
INDEX 2, DONE!
INDEX 3, DONE!
INDEX 4, DONE!
INDEX 5, DONE!
INDEX 6, DONE!
INDEX 7, DONE!
INDEX 8, DONE!
INDEX 9, DONE!
INDEX 10, DONE!
INDEX 11, DONE!
INDEX 12, DONE!
INDEX 13, DONE!
INDEX 14, DONE!
INDEX 15, DONE!
INDEX 16, DONE!
INDEX 17, DONE!
INDEX 18, DONE!
INDEX 19, DONE!
INDEX 20, DONE!
INDEX 21, DONE!
INDEX 22, DONE!
INDEX 23, DONE!
INDEX 24, DONE!
INDEX 25, DONE!
INDEX 26, DONE!
INDEX 27, DONE!
INDEX 28, DONE!
INDEX 29, DONE!
INDEX 30, DONE!
INDEX 31, DONE!
INDEX 32, DONE!
INDEX 33, DONE!
INDEX 34, DONE!
INDEX 35, DONE!
INDEX 36, DONE!
INDEX 37, DONE!
INDEX 38, DONE!
INDEX 39, DONE!
INDEX 40, DONE!
INDEX 41, DONE!
INDEX 42, DONE!
INDEX 43, DONE!
INDEX 44, DONE!
INDEX 45, DONE!
INDEX 46, DONE!
INDEX 47, DONE!
INDEX 48, DONE!
INDEX 49, DONE!
INDEX 50, DONE!
INDEX 51, DONE!
INDEX 52, DONE!
INDEX 53, DONE!
INDEX 54, DONE!
INDEX 55, DONE!
INDEX 56, DONE!
