In [2]:
import glob
import os
from PIL import Image

import numpy as np

import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

from network import resnet101

In [None]:
class CFG:
    """Static configuration shared by the feature-extraction cells.

    Plain class attributes are used as a namespace; the class is never
    instantiated.
    """

    # Base DataLoader batch size.
    batch_size = 8
    # Upper-case alias so both ``CFG.batch_size`` and ``CFG.BATCH_SIZE`` resolve.
    BATCH_SIZE = batch_size

    # Checkpoints holding a ``state_dict`` entry for each stream.
    img_checkpoint_path = './weights/model_best_img.pth.tar'
    flow_checkpoint_path = './weights/model_best_flow.pth.tar'

    # Input roots: one sub-directory per video.
    img_dir = '../input/frames'
    flow_dir = '../input/flows'

    # Output directories for the saved ``<video_id>.npy`` feature files.
    img_feature_dir = './output/img_feature'
    flow_feature_dir = './output/flow_feature'
    # Legacy misspelled name, kept so existing references keep working.
    flow_featuer_dir = flow_feature_dir

    channel = 10  # number of consecutive flow frames stacked per clip
    mode = 'rgb'  # which stream to run: 'rgb' or 'flow'

In [None]:
class Spatial_Dataset(Dataset):
    """Dataset over the ``*.png`` frames found directly inside ``root_dir``.

    Frames are served in sorted filename order, so zero-padded names such as
    ``000000.png, 000001.png, ...`` come back in frame order.

    Args:
        root_dir: Directory containing the frame images of one video.
        img_rows: Stored on the instance; not used by this class itself.
        img_cols: Stored on the instance; not used by this class itself.
        transform: Optional callable applied to each RGB ``PIL.Image``.
    """

    def __init__(self, root_dir, img_rows=224, img_cols=224, transform=None):
        self.root_dir = root_dir
        self.img_rows = img_rows
        self.img_cols = img_cols
        # Keep the caller's transform (it was previously overwritten with None).
        self.transform = transform

        self.paths = sorted(glob.glob(os.path.join(root_dir, "*.png")))

    def __getitem__(self, index):
        """Return the ``index``-th frame, transformed when a transform is set.

        Raises:
            IndexError: If ``index`` is outside the list of discovered frames.
        """
        # Index the discovered paths so __getitem__ and __len__ always agree.
        with Image.open(self.paths[index]) as handle:
            # convert() loads the pixels, so the file can be closed right away;
            # forcing RGB gives three channels regardless of the PNG's mode.
            img = handle.convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        return img

    def __len__(self):
        """Number of frames found in ``root_dir``."""
        return len(self.paths)

In [None]:
class Motion_Dataset(Dataset):
    """Dataset of stacked optical-flow clips built from per-frame ``.npy`` files.

    The ``*.npy`` files in ``root_dir`` are sorted by name and grouped into
    consecutive, non-overlapping clips of ``channel`` frames. Item ``k`` is a
    float tensor of shape ``(2 * channel, img_rows, img_cols)`` in which rows
    ``2*j`` and ``2*j + 1`` hold ``[..., 0]`` and ``[..., 1]`` of the clip's
    ``j``-th frame (named x and y in the original code). Trailing frames that
    do not fill a whole clip are ignored.

    Args:
        root_dir: Directory containing the flow arrays of one video.
        channel: Number of consecutive frames stacked into one clip.
        img_rows: Height of every plane after ``transform``.
        img_cols: Width of every plane after ``transform``.
        transform: Optional callable applied to each 2-D flow plane (a NumPy
            array). Its result must be assignable into an
            ``(img_rows, img_cols)`` tensor slice. When ``None`` the plane is
            converted to a tensor unchanged.

    Raises:
        ValueError: If ``channel`` is not positive.
    """

    def __init__(self, root_dir, channel=10, img_rows=224, img_cols=224, transform=None):
        if channel <= 0:
            raise ValueError(f"channel must be positive, got {channel}")
        self.channel = channel
        self.root_dir = root_dir
        self.img_rows = img_rows
        self.img_cols = img_cols
        # Keep the caller's transform (it was previously overwritten with None).
        self.transform = transform

        self.paths = sorted(glob.glob(os.path.join(root_dir, "*.npy")))

    def _prepare(self, plane):
        """Turn one 2-D flow plane into something assignable into the clip tensor."""
        if self.transform is not None:
            return self.transform(plane)
        return torch.as_tensor(plane)

    def __getitem__(self, index):
        """Return the ``index``-th clip as a ``(2 * channel, rows, cols)`` tensor.

        Raises:
            IndexError: If ``index`` is not in ``[0, len(self))``.
        """
        if not 0 <= index < len(self):
            raise IndexError(f"clip index {index} out of range for {len(self)} clips")

        flow = torch.empty(2 * self.channel, self.img_rows, self.img_cols, dtype=torch.float32)
        first = index * self.channel
        # ``j`` is the frame position inside the clip; using it (rather than the
        # absolute frame number) keeps every write inside the 2*channel buffer.
        for j, path in enumerate(self.paths[first:first + self.channel]):
            frame = np.load(path)
            flow[2 * j, :, :] = self._prepare(frame[..., 0])
            flow[2 * j + 1, :, :] = self._prepare(frame[..., 1])

        return flow

    def __len__(self):
        """Number of complete ``channel``-frame clips available."""
        return len(self.paths) // self.channel

In [None]:
def spatial_loader(video_id):
    """Build a DataLoader over the RGB frames of one video.

    Args:
        video_id: Name of the video's sub-directory under ``CFG.img_dir``.

    Returns:
        A non-shuffling ``DataLoader`` whose batches hold up to
        ``CFG.batch_size * 4`` frames, each resized to 224x224, converted to
        a tensor and normalized per channel.
    """
    dataset = Spatial_Dataset(
        root_dir=os.path.join(CFG.img_dir, video_id),
        transform=transforms.Compose([
            # ``Resize`` is the replacement for the removed ``transforms.Scale``.
            transforms.Resize([224, 224]),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]),
    )

    loader = DataLoader(
        dataset=dataset,
        batch_size=CFG.batch_size * 4,
        shuffle=False,  # keep frame order so output rows line up with frames
        # os.cpu_count() may return None; DataLoader needs an int.
        num_workers=os.cpu_count() or 0,
        pin_memory=True,
    )

    return loader

In [None]:
def motion_loader(video_id):
    """Build a DataLoader over the stacked optical-flow clips of one video.

    Args:
        video_id: Name of the video's sub-directory under ``CFG.flow_dir``.

    Returns:
        A non-shuffling ``DataLoader`` whose batches hold up to
        ``CFG.batch_size`` clips of ``CFG.channel`` frames each.
    """
    dataset = Motion_Dataset(
        channel=CFG.channel,
        root_dir=os.path.join(CFG.flow_dir, video_id),
        transform=transforms.Compose([
            # The dataset hands each flow plane over as a NumPy array, which
            # Resize cannot take directly, so convert to a tensor first.
            # NOTE(review): resizing tensors needs a torchvision release with
            # tensor transform support - confirm the installed version.
            transforms.ToTensor(),
            # ``Resize`` is the replacement for the removed ``transforms.Scale``.
            transforms.Resize([224, 224]),
        ]),
    )

    loader = DataLoader(
        dataset=dataset,
        batch_size=CFG.batch_size,
        shuffle=False,  # keep clip order so output rows line up with frames
        # os.cpu_count() may return None; DataLoader needs an int.
        num_workers=os.cpu_count() or 0,
        pin_memory=True,
    )

    return loader

In [None]:
def feature_extraction(model, x):
    """Run ``x`` through the model's layers up to ``avgpool`` and flatten.

    Applies, in order, ``conv1_custom``, ``bn1``, ``relu``, ``maxpool``,
    ``layer1`` .. ``layer4`` and ``avgpool`` of ``model``. No further
    attribute of ``model`` is called, so the result is the pooled feature.

    Args:
        model: Object exposing the layer attributes listed above.
        x: Input batch tensor accepted by ``model.conv1_custom``.

    Returns:
        Tensor of shape ``(x.size(0), -1)``: one flattened feature row per
        batch element.
    """
    x = model.conv1_custom(x)
    x = model.bn1(x)
    x = model.relu(x)
    x = model.maxpool(x)

    x = model.layer1(x)
    x = model.layer2(x)
    x = model.layer3(x)
    x = model.layer4(x)

    x = model.avgpool(x)
    x = x.view(x.size(0), -1)

    # The result was previously dropped, making the function return None.
    return x

In [None]:
def extract_video_features(model, loader):
    """Compute features for every batch of ``loader`` and stack them.

    Each batch is moved to the GPU and passed through ``feature_extraction``
    with gradients disabled. Every resulting row is then repeated
    ``CFG.channel`` times along axis 0 before the batches are concatenated.

    NOTE(review): the repeat is applied whatever ``CFG.mode`` is - confirm it
    is intended when the loader yields single RGB frames rather than clips.

    Args:
        model: Network handed to ``feature_extraction``.
        loader: Iterable yielding input batch tensors.

    Returns:
        A NumPy array of all repeated feature rows, concatenated along
        axis 0. Its shape is also printed.
    """
    collected = []

    for batch in loader:
        batch = batch.cuda()
        with torch.no_grad():
            batch_feature = feature_extraction(model, batch)
        as_numpy = batch_feature.cpu().detach().numpy()
        # Repeat each row once per frame of the clip.
        repeated = np.repeat(as_numpy, CFG.channel, axis=0)
        collected.append(repeated)

    features = np.concatenate(collected, axis=0)
    print(features.shape)

    return features

In [None]:
def _load_model(in_channels, checkpoint_path):
    """Build a ResNet-101 on the GPU, load its checkpoint and set eval mode."""
    model = resnet101(pretrained=False, channel=in_channels).cuda()
    model.load_state_dict(torch.load(checkpoint_path)['state_dict'])
    model.eval()
    return model


def _extract_and_save(model, input_dir, make_loader, output_dir):
    """Save one ``<video_id>.npy`` feature file per entry of ``input_dir``."""
    # np.save does not create missing directories.
    os.makedirs(output_dir, exist_ok=True)

    # Sorted so the processing order (and the log) is deterministic.
    for video_id in sorted(os.listdir(input_dir)):
        print(video_id)
        loader = make_loader(video_id)
        features = extract_video_features(model, loader)
        np.save(os.path.join(output_dir, f"{video_id}.npy"), features)


def main():
    """Extract and save per-video features for the stream chosen by ``CFG.mode``.

    ``"rgb"`` runs the 3-channel model over the frame directories under
    ``CFG.img_dir``; ``"flow"`` runs the ``2 * CFG.channel``-channel model
    over the flow directories under ``CFG.flow_dir``.

    Raises:
        ValueError: If ``CFG.mode`` is neither ``"rgb"`` nor ``"flow"``.
    """
    if CFG.mode == "rgb":
        model = _load_model(3, CFG.img_checkpoint_path)
        _extract_and_save(model, CFG.img_dir, spatial_loader, CFG.img_feature_dir)

    elif CFG.mode == "flow":
        model = _load_model(CFG.channel * 2, CFG.flow_checkpoint_path)
        # Accept either spelling of the output-directory attribute: the
        # config has carried the misspelled ``flow_featuer_dir``.
        flow_out_dir = getattr(CFG, "flow_feature_dir", None) or CFG.flow_featuer_dir
        _extract_and_save(model, CFG.flow_dir, motion_loader, flow_out_dir)

    else:
        # An unknown mode used to fall through and silently do nothing.
        raise ValueError(f"unknown CFG.mode {CFG.mode!r}; expected 'rgb' or 'flow'")

In [None]:
# Run the extraction for the stream selected by CFG.mode.
main()