In [None]:
from algorithms.feature_extraction_loading import extract_diffusion_features

# Invoke the project's extraction routine for a single source: the key "davis"
# mapped to the pickle path below, with VAE slicing switched off.
# NOTE(review): presumably this writes extracted features to disk for the
# later cells to load -- confirm in algorithms.feature_extraction_loading.
extract_diffusion_features(
    {"davis": "tapvid_davis/tapvid_davis.pkl"},
    enable_vae_slicing=False,
)

In [None]:
import pickle
from algorithms.utils import read_config_file, feature_collate_fn
from algorithms.feature_extraction_loading import FeatureDataset
from torch.utils.data import DataLoader

# Project configuration, read from the YAML file; the cells below index it
# by 'dataset_dir'.
config = read_config_file("configs/config.yaml")

# Dataset over the directory named by config['dataset_dir'].
dataset = FeatureDataset(
    feature_dataset_path=config['dataset_dir'],
)

# No batch_size is passed, so the DataLoader default applies; batches are
# assembled by the project's feature_collate_fn.
dataloader = DataLoader(
    dataset=dataset,
    collate_fn=feature_collate_fn,
)

In [None]:
# Create feature tensor for swan
import torch
import torch.nn.functional as func


def _resample_channels(tensor, num_channels):
    """Resize the channel axis of a 4-D (F, C, H, W) tensor to ``num_channels``.

    ``func.interpolate`` only resizes the trailing axes of its input, so the
    channel axis is moved last, the tensor is folded to 3-D ``(F*H, W, C)``,
    interpolated along that last axis (default interpolation mode), and the
    original ``(F, num_channels, H, W)`` layout is restored.

    Args:
        tensor: 4-D tensor laid out as (F, C, H, W).
        num_channels: target size of the channel axis.

    Returns:
        A (F, num_channels, H, W) tensor (a permuted, non-contiguous view of
        the interpolation result).
    """
    F, C, H, W = tensor.shape
    # permute makes the tensor non-contiguous; .contiguous() is required
    # before .view can fold the first two axes together.
    folded = tensor.permute((0, 2, 3, 1)).contiguous().view(F * H, W, C)
    folded = func.interpolate(folded, size=num_channels)
    return folded.view(F, H, W, num_channels).permute(0, 3, 1, 2)


def combine_feaures(feature_dict):
    """Fuse the per-block feature maps in ``feature_dict`` into one tensor.

    Starting from the coarsest maps, the running feature is repeatedly
    upsampled spatially by a factor of 2, summed with the maps of the next
    resolution, and has its channel axis resized (1280 -> 640 -> 320 -> 256
    -> 128) so that the element-wise sums line up.

    Args:
        feature_dict: mapping with the keys "mid_block", "down_block",
            "up_block", "encoder_block" and "decoder_block". Each value is
            indexable and yields 4-D tensors (F, C, H, W); every tensor is
            cast with ``.float()`` before use. Indices read: mid_block[0],
            down_block[0..3], up_block[0..2], encoder_block[0..3],
            decoder_block[0..3].
            # NOTE(review): the spatial sizes in the trailing comments
            # (4x4 ... 256x256) are the original author's annotations and
            # presumably correspond to a 256x256 input -- confirm against the
            # extraction code. Only the 2x ratios between levels are required
            # by the code itself.

    Returns:
        Tensor with 128 channels at the resolution of decoder_block[3].
    """
    f = feature_dict

    # --- 1280-channel levels -------------------------------------------
    feature = f["mid_block"][0].float() + f["down_block"][2].float() + f["down_block"][3].float()  # 1280x4x4
    feature = func.interpolate(feature, scale_factor=2) + f["up_block"][0].float()  # 1280x8x8
    feature = func.interpolate(feature, scale_factor=2) + f["up_block"][1].float()  # 1280x16x16

    # --- 640 channels --------------------------------------------------
    feature = _resample_channels(feature, 640)  # 640x16x16
    # down_block[1] is one level coarser, so it is upsampled to match.
    feature = feature + func.interpolate(f["down_block"][1].float(), scale_factor=2)  # 640x16x16

    # --- 320 channels --------------------------------------------------
    feature = _resample_channels(feature, 320)  # 320x16x16
    feature = feature + f["down_block"][0].float()

    encoder_skip = _resample_channels(
        f["encoder_block"][2].float() + f["encoder_block"][3].float(), 320
    )  # 320x32x32
    up_skip = _resample_channels(f["up_block"][2].float(), 320)  # 320x32x32
    feature = func.interpolate(feature, scale_factor=2) + encoder_skip + up_skip  # 320x32x32

    # --- 256 channels --------------------------------------------------
    feature = _resample_channels(feature, 256)  # 256x32x32
    decoder_skip = _resample_channels(f["decoder_block"][0].float(), 256)  # 256x64x64
    feature = func.interpolate(feature, scale_factor=2) + decoder_skip + f["encoder_block"][1].float()  # 256x64x64

    # --- 128 channels --------------------------------------------------
    feature = _resample_channels(feature, 128)  # 128x64x64
    decoder_skip = _resample_channels(f["decoder_block"][1].float(), 128)  # 128x128x128
    feature = func.interpolate(feature, scale_factor=2) + decoder_skip + f["encoder_block"][0].float()  # 128x128x128

    decoder_skip = _resample_channels(f["decoder_block"][2].float(), 128)  # 128x256x256
    feature = func.interpolate(feature, scale_factor=2) + decoder_skip + f["decoder_block"][3].float()  # 128x256x256

    return feature

In [None]:
import os

# For every batch from the dataloader: fuse the 'features' of its first
# element, attach the result under 'better_feature', and pickle the whole
# batch object to video_<i>.pkl inside config['dataset_dir'].
# NOTE(review): the output directory is the same one FeatureDataset was
# pointed at -- confirm these files are not overwriting its inputs.
for i, data in enumerate(dataloader):
    sample = data[0]
    better_feature = combine_feaures(sample["features"])
    sample['better_feature'] = better_feature

    feature_path = os.path.join(config['dataset_dir'], f"video_{i}.pkl")
    with open(feature_path, 'wb') as better_feature_file:
        pickle.dump(data, better_feature_file)
    