In [94]:
import logging
import os.path as osp
from glob import glob
from os import makedirs
from sys import stdout

import numpy as np

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import ImageFolder
import torchvision.models as models
from tqdm import tqdm
import scipy.io

def choose_device(use_cuda=True):
    """Pick the torch device to compute on and print the choice.

    Args:
        use_cuda: When true, CUDA is selected if ``torch.cuda.is_available()``
            reports it; in every other case the CPU is used.

    Returns:
        torch.device: The ``cuda`` or ``cpu`` device.
    """
    if use_cuda and torch.cuda.is_available():
        backend = "cuda"
    else:
        backend = "cpu"
    selected = torch.device(backend)
    # Message format is kept exactly (no space after the colon).
    print('Using:' + str(selected))
    return selected

def get_network(device):
    """Build a pretrained VGG-19 truncated so it can serve as a feature extractor.

    The final layer of ``net.classifier`` is removed, so a forward pass returns
    the output of the layer before it instead of the original final output.
    The model's train/eval mode is not changed here; callers set it.

    Args:
        device: torch device the model is moved to.

    Returns:
        The truncated model, moved to ``device``.
    """
    backbone = models.vgg19(pretrained=True)
    # Alternative backbones kept for reference:
    #   models.alexnet(pretrained=True)
    #   models.densenet121(pretrained=True)
    head_layers = list(backbone.classifier.children())
    # Keep everything except the last classifier layer.
    backbone.classifier = nn.Sequential(*head_layers[:-1])
    return backbone.to(device)

def get_xforms(augment=False):
    """Build the preprocessing pipeline applied to each image before the network.

    By default the pipeline is deterministic (resize to 256, centre-crop to
    224), so extracting features from the same image twice yields the same
    result. The previous behaviour used random crops and flips, which are
    training-time augmentations and made the saved features differ from run
    to run.

    Args:
        augment: When true, use the random crop / random horizontal flip
            pipeline instead of the deterministic one.

    Returns:
        transforms.Compose: spatial transform, ``ToTensor`` and per-channel
        normalisation, in that order.
    """
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    if augment:
        spatial = [
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
        ]
    else:
        # Deterministic evaluation-style preprocessing.
        spatial = [
            transforms.Resize(256),
            transforms.CenterCrop(224),
        ]
    return transforms.Compose(spatial + [transforms.ToTensor(), normalize])

def get_dataloader(im_dir, xforms, batch_size=4, num_workers=4):
    """Wrap an image directory in an unshuffled ``DataLoader``.

    Args:
        im_dir: Root directory handed to ``ImageFolder``.
        xforms: Transform handed to ``ImageFolder`` for each image.
        batch_size: Number of samples per batch.
        num_workers: Number of loader worker processes.

    Returns:
        DataLoader: Iterates the folder in a fixed order (``shuffle=False``).
    """
    dataset = ImageFolder(im_dir, xforms)
    loader_options = dict(
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
    )
    return DataLoader(dataset, **loader_options)

def get_features(raw_image, network, device):
    """Run every batch through ``network`` and stack the outputs into one array.

    Args:
        raw_image: Iterable of batches (e.g. a ``DataLoader``); element 0 of
            each batch is the input tensor. Any further elements are ignored.
        network: Callable applied to each input batch.
        device: torch device the inputs are moved to before the forward pass.

    Returns:
        numpy.ndarray: Per-batch outputs stacked vertically, in iteration order.

    Raises:
        ValueError: If ``raw_image`` yields no batches (raised by ``np.vstack``).
    """
    features = []
    # Inference only: avoid building autograd graphs even if the caller
    # did not wrap the call in no_grad itself.
    with torch.no_grad():
        # Iterate the loader directly so tqdm can read its length and show
        # a total / ETA when one is available.
        for batch in tqdm(raw_image):
            images = batch[0].to(device)
            # Move each result off the device right away so device memory
            # use stays bounded by a single batch.
            features.append(network(images).detach().cpu().numpy())
    return np.vstack(features)

In [95]:
def pipeline(im_dir, out_path='wiki_train_img_vgg19.mat', key='train_img'):
    """Extract network features for every image under ``im_dir`` and save them.

    Args:
        im_dir: Image root directory passed to ``get_dataloader``.
        out_path: Destination ``.mat`` file. Defaults to the previously
            hard-coded file name.
        key: Variable name the feature matrix is stored under inside the
            ``.mat`` file. Defaults to the previously hard-coded name.

    Returns:
        None. The features are written to ``out_path``.
    """
    with torch.no_grad():
        device = choose_device()
        network = get_network(device)
        # Put the model in evaluation mode before extracting features.
        network.eval()
        raw_image = get_dataloader(im_dir, get_xforms())
        features = get_features(raw_image, network, device)
    scipy.io.savemat(out_path, mdict={key: features})

In [96]:


# Entry point: run the feature-extraction pipeline on the training images.
# The directory is relative to the current working directory.
if __name__ == '__main__':
    pipeline('./raw_image_wiki/train/')

Using:cuda


544it [00:16, 32.85it/s]
