In [64]:
%reload_ext autoreload
%autoreload 2

from build_dataset import create_dataloader
from main import get_args

import torch
import cv2
import numpy as np
import os
import json
from PIL import Image

from glob import glob
from transformers import GPT2Tokenizer

from tqdm.auto import tqdm

In [2]:
# Base arguments come from main.get_args; NOTE(review): the meaning of the positional
# True flag is defined there — confirm before changing it.
config_yaml_path = '../config/captioning_config.yaml'
args = get_args(True)
args.yaml_file = config_yaml_path

In [3]:
# GPT-2 tokenizer plus an extra [CLS] special token; the call's return value (the
# number of tokens actually added) is shown as the cell output.
special_tokens = dict(cls_token='[CLS]')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.add_special_tokens(special_tokens)

1

In [4]:
# Loaders with use_saved_frame=False (frames are not taken from the .pt cache).
args.batch_size, args.num_workers, args.use_saved_frame = 2, 4, False

trainloader, validloader, testloader = (
    create_dataloader(args, split, tokenizer) for split in ('train', 'val', 'test')
)

# 1. 비디오 존재 확인

In [25]:
def check_dataloader(dataloader):
    """Collect the ids of videos whose file does not exist on disk.

    Args:
        dataloader: loader whose ``dataset.video_info`` table has ``video_path`` and
            ``video_id`` columns.

    Returns:
        list: ``video_id`` of every row whose ``video_path`` is not an existing regular
        file, in table order.
    """
    video_info = dataloader.dataset.video_info
    path_id_pairs = zip(video_info['video_path'], video_info['video_id'])
    return [video_id for video_path, video_id in path_id_pairs if not os.path.isfile(video_path)]

In [26]:
# Ids of videos whose file is missing, per split.
error_train_video_ids, error_valid_video_ids, error_test_video_ids = (
    check_dataloader(loader) for loader in (trainloader, validloader, testloader)
)

# 2. OpenCV 문제 확인

In [91]:
def check_capture(dataloader):
    """Return the positions of videos that OpenCV cannot open or time-index.

    For each row of ``dataloader.dataset.video_info``, the file at ``video_path`` is
    opened with ``cv2.VideoCapture`` and handed to the dataset's own
    ``extract_time_list``. A row is reported when the capture cannot be created, does
    not open, or ``extract_time_list`` raises.

    Args:
        dataloader: loader whose ``dataset`` exposes ``video_info`` (a table with a
            ``video_path`` column) and ``extract_time_list(cap)``.

    Returns:
        list[int]: positional (``iloc``-style) indices of the problematic rows.
    """
    video_paths = dataloader.dataset.video_info['video_path']
    error_i = []
    for i, video_path in enumerate(video_paths):
        print(i, end='\r')
        cap = None
        try:
            cap = cv2.VideoCapture(video_path)
            # VideoCapture does not raise on an unreadable file; it only reports
            # that it failed to open, so check explicitly.
            if not cap.isOpened():
                error_i.append(i)
                continue
            dataloader.dataset.extract_time_list(cap)
        except Exception:
            # `except Exception` rather than a bare except so Ctrl-C still stops the scan.
            error_i.append(i)
        finally:
            # Release the decoder/file handle for every video, including failed ones.
            if cap is not None:
                cap.release()

    return error_i

In [92]:
train_error_i = check_capture(dataloader=trainloader)  # positions of unreadable train videos

8267

In [93]:
val_error_i = check_capture(dataloader=validloader)  # positions of unreadable val videos

2080

In [94]:
test_error_i = check_capture(dataloader=testloader)  # positions of unreadable test videos

2081

In [95]:
for split_label, split_errors in (
    ('train', train_error_i),
    ('valid', val_error_i),
    ('test', test_error_i),
):
    print(f'len({split_label}_error_i): ', len(split_errors))

len(train_error_i):  0
len(valid_error_i):  0
len(test_error_i):  0


# 3. 비디오 프레임 저장 (.pt)

In [5]:
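# One-off annotation cleanup, kept for provenance and intentionally left commented out:
# it removes two video ids from the train annotation and writes the result to a new JSON file.
# del trainloader.dataset.annotation['CcPrDDRuHrg']
# del trainloader.dataset.annotation['kvaefb9jAHE']

# json.dump(trainloader.dataset.annotation,
#           open('../datasets/annotations/trainset_highest_f1_removed2video0520.json','w'))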

In [4]:
# Same loader settings, rebuilt with test_mode=True for the frame-export pass below.
args.batch_size, args.num_workers, args.use_saved_frame = 2, 4, False
trainloader, validloader, testloader = (
    create_dataloader(args, split, tokenizer, test_mode=True)
    for split in ('train', 'val', 'test')
)

In [5]:
def save_frames(dataloader, split, frames_dir):
    """Cache each boundary sample as ``<frames_dir>/<split>/<boundary_id>.pt``.

    Samples whose ``.pt`` file already exists are skipped, so the function can be
    re-run to resume an interrupted export. Each file stores ``dict([dataset[idx]])``.
    NOTE(review): that form assumes ``dataset[idx]`` is a single ``(key, value)``
    pair; confirm against the dataset's ``__getitem__`` in test_mode.

    Args:
        dataloader: loader whose ``dataset`` exposes ``boundary_list`` (dicts with a
            ``'boundary_id'`` key) and positional ``__getitem__``.
        split: sub-directory name under ``frames_dir`` (e.g. ``'train'``).
        frames_dir: root directory for cached frames; created if missing.

    Returns:
        list: boundary ids whose sample could not be loaded or saved.
    """
    split_dir = os.path.join(frames_dir, split)
    os.makedirs(split_dir, exist_ok=True)

    error_boundary_id = []
    boundary_list = dataloader.dataset.boundary_list
    n_boundaries = len(boundary_list)

    for idx, boundary_i in enumerate(boundary_list):
        print(f'{idx+1} / {n_boundaries}', end='\r')
        boundary_id = boundary_i['boundary_id']
        out_path = os.path.join(split_dir, f'{boundary_id}.pt')
        # One stat per sample instead of re-listing the whole directory each
        # iteration; the listdir version was O(n^2) over tens of thousands of files.
        if os.path.exists(out_path):
            continue

        tmp_path = out_path + '.tmp'
        try:
            torch.save(dict([dataloader.dataset[idx]]), tmp_path)
            # Write-then-rename: an interrupted save never leaves a truncated .pt
            # that a later resume would mistake for a finished file.
            os.replace(tmp_path, out_path)
        except Exception:
            # `except Exception` (not bare) so Ctrl-C aborts instead of skipping ahead.
            error_boundary_id.append(boundary_id)
        finally:
            # Never leave a stray .tmp behind; the saved-file counts rely on a clean dir.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    return error_boundary_id

In [6]:
frames_dir = '/datasets/GEBC/frames'  # root of the .pt frame cache
save_frames(dataloader=trainloader, split='train', frames_dir=frames_dir)

26623 / 26623

[]

In [7]:
frames_dir = '/datasets/GEBC/frames'  # root of the .pt frame cache
save_frames(dataloader=validloader, split='val', frames_dir=frames_dir)

6748 / 6748

[]

In [8]:
frames_dir = '/datasets/GEBC/frames'  # root of the .pt frame cache
save_frames(dataloader=testloader, split='test', frames_dir=frames_dir)

6619 / 6619

[]

In [9]:
# Expected cache file names: one "<boundary_id>.pt" per boundary in each split.
train_b, val_b, test_b = (
    [f"{b['boundary_id']}.pt" for b in loader.dataset.boundary_list]
    for loader in (trainloader, validloader, testloader)
)

for list_name, file_names in (('train_b', train_b), ('val_b', val_b), ('test_b', test_b)):
    print(f'len({list_name}): ', len(file_names))

len(train_b):  26623
len(val_b):  6748
len(test_b):  6619


In [10]:
frames_dir = '/datasets/GEBC/frames'
# File names actually present in each split's cache directory.
train_save, val_save, test_save = (
    os.listdir(os.path.join(frames_dir, split)) for split in ('train', 'val', 'test')
)

for list_name, file_names in (('train_save', train_save), ('val_save', val_save), ('test_save', test_save)):
    print(f'len({list_name}): ', len(file_names))

len(train_save):  26623
len(val_save):  6748
len(test_save):  6619


In [14]:
# Cross-check each split: the cache must contain exactly one .pt per boundary and
# nothing else. Print both differences (as before), then fail the cell on a mismatch
# so a broken export cannot slip through a Restart & Run All unnoticed.
mismatched_splits = []
for split, saved, expected in (
    ('train', train_save, train_b),
    ('val', val_save, val_b),
    ('test', test_save, test_b),
):
    saved_set, expected_set = set(saved), set(expected)
    print(f'set({split}_save) - set({split}_b): ', saved_set - expected_set)
    print(f'set({split}_b) - set({split}_save): ', expected_set - saved_set)
    if saved_set != expected_set:
        mismatched_splits.append(split)

assert not mismatched_splits, f'saved frames differ from the boundary list for: {mismatched_splits}'

set(train_save) - set(train_b):  set()
set(train_b) - set(train_save):  set()
set(val_save) - set(val_b):  set()
set(val_b) - set(val_save):  set()
set(test_save) - set(test_b):  set()
set(test_b) - set(test_save):  set()


# 4. 저장된 프레임 로딩 확인

In [17]:
# Rebuild the loaders on top of the .pt cache (use_saved_frame=True, test_mode=False).
args.batch_size, args.num_workers, args.use_saved_frame = 2, 4, True
trainloader, validloader, testloader = (
    create_dataloader(args, split, tokenizer, test_mode=False)
    for split in ('train', 'val', 'test')
)

In [18]:
first_batch = next(iter(trainloader))  # one batch from the cache-backed train loader
boundary_id, captions, frames, labels = first_batch

In [24]:
for frame_key, frame_tensor in frames.items():
    print(f'{frame_key}.size(): {frame_tensor.shape}')

boundary.size(): torch.Size([2, 3, 224, 224])
before.size(): torch.Size([2, 10, 3, 224, 224])
after.size(): torch.Size([2, 10, 3, 224, 224])


In [70]:
import matplotlib.pyplot as plt
import os
from PIL import Image
from IPython.display import Image as Img
from IPython.display import display
def generate_gif(img_list, out_path='out.gif', duration=500, loop=0xff):
    """Write frames to an animated GIF and return an IPython image that displays it.

    Args:
        img_list: iterable of arrays accepted by ``PIL.Image.fromarray`` (e.g. uint8
            (H, W, 3) frames as prepared in the cell above).
        out_path: file the GIF is written to; defaults to ``'out.gif'`` in the
            working directory (the previous hard-coded location).
        duration: display time of each frame in milliseconds (500 = 0.5 s).
        loop: passed to Pillow's GIF writer as the loop count (0 means loop forever).

    Returns:
        ``IPython.display.Image`` referencing ``out_path`` by URL.

    Raises:
        ValueError: if ``img_list`` yields no frames.
    """
    images = [Image.fromarray(x) for x in img_list]
    if not images:
        raise ValueError('generate_gif needs at least one frame')

    first_frame, *other_frames = images
    first_frame.save(out_path, save_all=True, append_images=other_frames, loop=loop, duration=duration)
    return Img(url=out_path)

In [86]:
first_caption_ids = captions['input_ids'][0]
tokenizer.decode(first_caption_ids)

'Subject: girl in blue and white bodysuit //Status_Before: jumping on the sponge //Status_After: jump and back flip on the sponge<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>

In [82]:
# First sample of the batch: its "before" frames, the boundary frame, then its
# "after" frames, moved to channels-last for PIL -> (21, 224, 224, 3) given the
# shapes printed above.
frames_hwc = torch.cat([
    frames['before'][0],
    frames['boundary'][0:1],
    frames['after'][0],
]).permute(0, 2, 3, 1).numpy()

# Min-max rescale each frame to 0..255 so the loader's float frames become viewable
# uint8 images. The epsilon guards a constant frame (max == min), which would
# otherwise divide by zero and turn NaNs into garbage uint8 pixels.
imgs = [
    ((img - img.min()) / max(img.max() - img.min(), 1e-8) * 255).astype(np.uint8)
    for img in frames_hwc
]

In [84]:
generate_gif(img_list=imgs)