In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision as torchv
import numpy as np
import mymodels
import mycoco_small_xian as mycoco
import torchvision.transforms as transforms
import torch.utils.data as data
from tensorboardX import SummaryWriter
import torchvision.utils as vutils
import os
import pdb
from models.model_AE_VGG16_2 import AE_VGG16_2 as AE_VGG16

from random import randint
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

import matplotlib.pyplot as plt
%matplotlib inline

from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image
from torch import nn
from copy import deepcopy
import datetime
import time

# parameters

In [2]:
param = {"datasetClasses": ['horse','zebra'],
        "check_point_path":'./Experiments/horseandzebra/weight/last_weight.pkl',
        "resultCount":10}

In [3]:
IMAGE_SIZE = 224
DATASET_NAMES = param["datasetClasses"]
BATCH_SIZE = 2
cuda = torch.cuda.is_available()

In [4]:
def create_dataset(name='train', batch_size=32):
    dataDir='/scratch/cluster-share/linzhe/cocoDataset/'
    dataType='%s2017'%name
    annFile='{}/annotations/instances_{}.json'.format(dataDir,dataType)
    root = '%s%s'%(dataDir, dataType)

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
    trans = transforms.Compose([transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
                                transforms.ToTensor()])
    target_transform = transforms.Compose([transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
                                           transforms.ToTensor()
                                           ]) # normalize to be in [-1, 1]
    coco = mycoco.CocoDetection(root, annFile, transform=trans, target_transform=target_transform, categories=DATASET_NAMES)
    cocoloader = data.DataLoader(coco, batch_size, num_workers=4, shuffle=True)
    
    return cocoloader, coco

In [5]:
cocoloader, cocoDataset = create_dataset(name='train', batch_size=BATCH_SIZE)
np.unique(cocoDataset.categories)

loading annotations into memory...
Done (t=19.30s)
creating index...
index created!


array(['horse', 'zebra'], dtype='<U5')

In [6]:
def get_bg_obj(image, mask):
    obj = np.multiply(image,mask)
    
    bg_mask = mask.clone() - 1
    bg_mask[bg_mask == -1] = 1
    
    bg = np.multiply(image,bg_mask)

    if cuda:
        bg = Variable(bg).cuda()
        mask = Variable(mask).cuda()
        obj = Variable(obj).cuda()
    else:
        bg = Variable(bg).cuda()
        mask = Variable(mask).cuda()
        obj = Variable(obj).cuda()
    
    return bg, mask, obj

In [7]:
def swichObject(model_weight_path, image1, mask1, image2, mask2, model, pair_num):
    
    bg1, mask1, obj1 = get_bg_obj(image1, mask1)
    bg2, mask2, obj2 = get_bg_obj(image2, mask2)    
    
    print(obj1.shape, mask2.shape)
    _, obj2_new = model(obj1, mask2)
    _, obj1_new = model(obj2, mask1)
    
    
    #print(obj1_new.shape, bg1.shape)
    #obj1_new *= (bg1 != 0).type(torch.cuda.FloatTensor)
    new_image1 = bg1 + obj1_new
    new_image2 = bg2 + obj2_new
    
    
    ts = time.time()
    st = datetime.datetime.fromtimestamp(ts).strftime('%Y-%m-%d %H:%M:%S')
    
    image1 = Variable(image1).cuda()
    image2 = Variable(image2).cuda()
        
    output = torch.cat((image1.data, obj1.data, bg1.data, obj1_new.data, new_image1.data,
                        image2.data, obj2.data, bg2.data, obj2_new.data, new_image2.data), dim = 0)
    
    print(output.shape)
    save_image(output, './swich_result_mask/' + st + '_result_' + str(pair_num) + '.png')
    '''
    save_image(mask1.data, './swich_result_mask/' + st + '_mask1.png')
    save_image(mask2.data, './swich_result_mask/' + st + '_mask2.png')
    '''

In [8]:
def output_resilt(check_point_path, pair_num = 1):
    model = AE_VGG16(IMAGE_SIZE)
    #dae.load('weights_mask/dae-999.pkl')
    #quick fix delete module to all keys
    d2 =  {str(k)[7:]:v for k,v in torch.load(check_point_path).items()}
    model.load_state_dict(d2)
    model.eval().cuda()
    
    for j in range(pair_num):
        for i, (images, masks) in enumerate(cocoloader):
            image1 = images[0].view(1,images[0].shape[0],images[0].shape[1],images[0].shape[2])
            image2 = images[1].view(1,images[0].shape[0],images[0].shape[1],images[0].shape[2])
            mask1 = masks[0].view(1,1,images[0].shape[1],images[0].shape[2])
            mask2 = masks[1].view(1,1,images[0].shape[1],images[0].shape[2])
            break
        
        swichObject(check_point_path, image1, mask1, image2, mask2, model, j)


In [9]:
output_resilt(check_point_path=param["check_point_path"], pair_num=param["resultCount"])

torch.Size([1, 3, 224, 224]) torch.Size([1, 1, 224, 224])
torch.Size([10, 3, 224, 224])
torch.Size([1, 3, 224, 224]) torch.Size([1, 1, 224, 224])
torch.Size([10, 3, 224, 224])
torch.Size([1, 3, 224, 224]) torch.Size([1, 1, 224, 224])
torch.Size([10, 3, 224, 224])
torch.Size([1, 3, 224, 224]) torch.Size([1, 1, 224, 224])
torch.Size([10, 3, 224, 224])
torch.Size([1, 3, 224, 224]) torch.Size([1, 1, 224, 224])
torch.Size([10, 3, 224, 224])
torch.Size([1, 3, 224, 224]) torch.Size([1, 1, 224, 224])
torch.Size([10, 3, 224, 224])
torch.Size([1, 3, 224, 224]) torch.Size([1, 1, 224, 224])
torch.Size([10, 3, 224, 224])
torch.Size([1, 3, 224, 224]) torch.Size([1, 1, 224, 224])
torch.Size([10, 3, 224, 224])
torch.Size([1, 3, 224, 224]) torch.Size([1, 1, 224, 224])
torch.Size([10, 3, 224, 224])
torch.Size([1, 3, 224, 224]) torch.Size([1, 1, 224, 224])
torch.Size([10, 3, 224, 224])
