In [2]:
from transformers import AutoImageProcessor, ViTModel
import torch
from PIL import Image
import  os
import matplotlib.pyplot as plt
import pickle
import numpy as np
import pandas as pd

image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")

In [3]:
def extract_feature(img_dir, image_processor, model, max_len = 10):
    feats = {}
    files = os.listdir(img_dir)
    
    for f in files[:max_len]:
        f_path = os.path.join(img_dir,f)
        image = Image.open(f_path)
        print("extracting feats of {}".format(f_path))
        inputs = image_processor(image, return_tensors="pt")
        with torch.no_grad():
            outputs = model(**inputs)

            last_hidden_states = outputs.last_hidden_state[:,0]
            # feats.append(last_hidden_states)
            feats[f] = last_hidden_states
        print("success extracting feats of {}".format(f_path))
    return feats

def get_result(ori_feats, back_feats, topn):
    result = {}
    for ori_f in ori_feats.keys():
        dist = {}
        for back_f in back_feats.keys():
            d = torch.nn.functional.cosine_similarity(ori_feats[ori_f], back_feats[back_f])
            # d = torch.ao.ns.fx.utils.compute_cosine_similarity(ori_feats[ori_f], back_feats[back_f])
            dist[back_f] = d
        dist = sorted(dist.items(), key=lambda d:d[1], reverse = True)
        result[ori_f] = dist[:topn]
    return result  

def combine_result(back_result, profile_result):
    for ori in back_result.keys():
        back_result[ori].append(profile_result[ori])
    return back_result

def show_result(back_result, profile_result,  ori_path, back_path, profile_path):
    plt.figure()
    row_num = len(back_result)
    fig, axs = plt.subplots(row_num, 5, figsize=(25, 15))
    i = 1
    for r in back_result.keys():
        image = Image.open(os.path.join(ori_path, r))
        plt.subplot(row_num, 5, i)
        i = i + 1
        plt.imshow(image)
        for rr in back_result[r]:
            image = Image.open(os.path.join(back_path, rr[0]))
            plt.subplot(row_num, 5, i)
            i = i + 1
            plt.imshow(image)
         
        image = Image.open(os.path.join(profile_path, profile_result[r][0][0]))
        plt.subplot(row_num, 5, i)
        i = i + 1
        plt.imshow(image)

In [None]:
ori_path = "./测试2/原图_small"
back_path="./测试2/场景图/正面"
profile_path="./测试2/场景图/侧面"

ori_feats = extract_feature(ori_path, image_processor, model, max_len = 100)
with open("./ori_feats", 'wb') as f:
    pickle.dump(ori_feats, f)
    
back_feats = extract_feature(back_path, image_processor, model, max_len = 100)
with open("./back_feats", 'wb') as f:
    pickle.dump(back_feats, f)
    
profile_feats = extract_feature(profile_path, image_processor, model, max_len = 100)
import pickle
with open("./profile_feats", 'wb') as f:
    pickle.dump(profile_feats, f)

In [4]:
with open("./ori_feats", 'rb') as f:
    ori_feats = pickle.load(f)

with open("./back_feats", 'rb') as f:
    back_feats = pickle.load(f)
    
with open("./profile_feats", 'rb') as f:
    profile_feats = pickle.load(f)

In [5]:
def gen_data(ori_feats, back_feats, data_path):
    data = []
    for ori in ori_feats:
        for back in back_feats:
            label = torch.randint(0,2,(1,))
            d = torch.cat([ori_feats[ori][0], back_feats[back][0], label])
            data.append(d)
    ret = torch.stack(data, 0)
    np.savetxt(data_path,ret.numpy(),fmt='%.2f',delimiter=',')
    return ret

In [7]:
import numpy as np
import pandas as pd

from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score

 
from numpy import vstack
from pandas import read_csv

from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torch import Tensor
from torch.nn import Linear
from torch.nn import ReLU
from torch.nn import Sigmoid
from torch.nn import Module
from torch.optim import SGD
from torch.nn import BCELoss
from torch.nn.init import kaiming_uniform_
from torch.nn.init import xavier_uniform_

In [8]:
# dataset definition
class CSVDataset(Dataset):
    # load the dataset
    def __init__(self, data_path):
        # load the csv file as a dataframe
        df = read_csv(data_path, header=None)
        # store the inputs and outputs        
        self.X = df.values[:, :-1]
        self.y = df.values[:, -1]
        # ensure input data is floats
        self.X = self.X.astype('float32')
        # label encode target and ensure the values are floats
        self.y = LabelEncoder().fit_transform(self.y)
        self.y = self.y.astype('float32')
        self.y = self.y.reshape((len(self.y), 1))
 
    # number of rows in the dataset
    def __len__(self):
        return len(self.X)
 
    # get a row at an index
    def __getitem__(self, idx):
        return [self.X[idx], self.y[idx]]
 
    # get indexes for train and test rows
    def get_splits(self, n_test=0.33):
        # determine sizes
        test_size = round(n_test * len(self.X))
        train_size = len(self.X) - test_size
        # calculate the split
        return random_split(self, [train_size, test_size])
 
# model definition
class MLP(Module):
    # define model elements
    def __init__(self, n_inputs):
        super(MLP, self).__init__()
        # input to first hidden layer
        self.hidden1 = Linear(n_inputs, 10)
        kaiming_uniform_(self.hidden1.weight, nonlinearity='relu')
        self.act1 = ReLU()
        # second hidden layer
        self.hidden2 = Linear(10, 8)
        kaiming_uniform_(self.hidden2.weight, nonlinearity='relu')
        self.act2 = ReLU()
        # third hidden layer and output
        self.hidden3 = Linear(8, 1)
        xavier_uniform_(self.hidden3.weight)
        self.hidden4 = Linear(2, 1)
        xavier_uniform_(self.hidden4.weight)
        self.act3 = Sigmoid()
 
    # forward propagate input
    def forward(self, X):
        # input to first hidden layer
        left = X[:,:768]
        right = X[:,768:]
        sim = torch.nn.functional.cosine_similarity(left, right)
        sim = torch.reshape(sim,(-1,1))
        X = self.hidden1(X)
        X = self.act1(X)
         # second hidden layer
        X = self.hidden2(X)
        X = self.act2(X)
        # third hidden layer and output
        X = self.hidden3(X)
#         print(X.shape)
#         print(sim.shape)
        X = torch.cat([X, sim], -1)
#         print(X.shape)
        X = self.hidden4(X)
        X = self.act3(X)
        return X

In [9]:
# prepare the dataset
def prepare_data(path):
    # load the dataset
    dataset = CSVDataset(path)
    # calculate split
    train, test = dataset.get_splits()
    # prepare data loaders
    train_dl = DataLoader(train, batch_size=32, shuffle=True)
    test_dl = DataLoader(test, batch_size=1024, shuffle=False)
    return train_dl, test_dl
 
# train the model
def train_model(train_dl, model):
    # define the optimization
    criterion = BCELoss()
    optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9)
    # enumerate epochs
    for epoch in range(100):
        # enumerate mini batches
        for i, (inputs, targets) in enumerate(train_dl):
            # clear the gradients
            optimizer.zero_grad()
            # compute the model output
            yhat = model(inputs)
            # calculate loss
            loss = criterion(yhat, targets)
            # credit assignment
            loss.backward()
            # update model weights
            optimizer.step()
 
# evaluate the model
def evaluate_model(test_dl, model):
    predictions, actuals = list(), list()
    for i, (inputs, targets) in enumerate(test_dl):
        # evaluate the model on the test set
        yhat = model(inputs)
        # retrieve numpy array
        yhat = yhat.detach().numpy()
        actual = targets.numpy()
        actual = actual.reshape((len(actual), 1))
        # round to class values
        yhat = yhat.round()
        # store
        predictions.append(yhat)
        actuals.append(actual)
    predictions, actuals = vstack(predictions), vstack(actuals)
    # calculate accuracy
    acc = accuracy_score(actuals, predictions)
    return acc
 
# make a class prediction for one row of data
def predict(ori_feat, back_feats, model):
    ret = []
    for back in back_feats:
        tmp = torch.cat([ori_feat[0], back[0]], -1)
        ret.append(tmp)
    pred_data = torch.stack(ret)
    print(pred_data.shape)
    # make prediction
    yhat = model(pred_data)
    return yhat

In [10]:
# with open("./ori_feats", 'rb') as f:
#     ori_feats = pickle.load(f)

# with open("./back_feats", 'rb') as f:
#     back_feats = pickle.load(f)
    
# with open("./profile_feats", 'rb') as f:
#     profile_feats = pickle.load(f)
train_data_path =  "./embd_data.csv"
train_data = gen_data(ori_feats, back_feats, data_path = train_data_path)

In [11]:
# prepare the data
train_dl, test_dl = prepare_data(train_data_path)
n_inputs = 1536
model = MLP(n_inputs)
# train the model
train_model(train_dl, model)
# evaluate the model
acc = evaluate_model(test_dl, model)
print('Accuracy: %.3f' % acc)
# make a single prediction (expect class=1)

Accuracy: 0.480


In [12]:
MODEL_PATH = "./model.pt" 
torch.save(model, MODEL_PATH)

In [13]:
r_model = torch.load(MODEL_PATH)
r_model.eval()

MLP(
  (hidden1): Linear(in_features=1536, out_features=10, bias=True)
  (act1): ReLU()
  (hidden2): Linear(in_features=10, out_features=8, bias=True)
  (act2): ReLU()
  (hidden3): Linear(in_features=8, out_features=1, bias=True)
  (hidden4): Linear(in_features=2, out_features=1, bias=True)
  (act3): Sigmoid()
)

In [40]:
def get_one_feats(img_name, feats_dict):
    return {img_name: feats_dict[img_name]}

def get_group_feats(img_dir, feats_dict):
    files = os.listdir(img_dir)
    feats = { k : feats_dict[k] for k in files}
    return feats

def cal_score(img_feats, back_feats, topN = 3): 
    ori_img = list(img_feats.keys())[0]
    ori_feas = list(img_feats.values())[0]
    back = []
    key = []
    for k in back_feats:
        back.append(f_back[k])
        key.append(k)
    y = predict(ori_feas, back, model)
    pred_result = {k:v.detach().numpy()[0] for k,v in zip(key, y)}
    ret = sorted(pred_result.items(), key=lambda d:d[1], reverse = True)
    return {ori_img : ret[:topN]}


In [31]:
f_ori = get_one_feats("1.png", ori_feats)

In [32]:
f_back = get_group_feats("./测试2/场景图/正面", back_feats)

In [41]:
cal_score(f_ori, f_back)

torch.Size([15, 1536])


{'1.png': [('客厅装饰画空白样机 (26).jpg', 0.68778807),
  ('客厅装饰画空白样机 (9).jpg', 0.6114947),
  ('客厅装饰画空白样机 (45).jpg', 0.60454434)]}