In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import time
import numpy as np
import os
import shutil
import  opennmt.inputters.record_inputter as inpu
import tensorflow as tf
import torch.utils.data as data
from PIL import Image
import os
import os.path
import pickle
from scipy.spatial.distance import cosine, euclidean,correlation
from sklearn.metrics import accuracy_score, confusion_matrix

In [None]:
# Model Definition

In [None]:
class I(torch.nn.Module):
    """Pass-through module: ``forward`` hands back its input object unchanged.

    ``MultiSignClf`` installs it as ``model.fc`` so the backbone's output is the
    tensor that would otherwise have been fed into that final ``fc`` layer.
    The class name must not change: full models saved with ``torch.save`` refer
    to it by name when they are unpickled.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        result = x
        return result

    def extra_repr(self):
        # Shown inside the module repr, e.g. ``I(identity)``.
        label = 'identity'
        return label


class MultiSignClf(torch.nn.Module):
    """Inception-v3 backbone (classifier replaced by ``I``) plus three heads.

    * ``fc_out``     - 61 outputs, applied to the first half of the batch.
    * ``fc_aux_out`` - 45 outputs, applied to every sample in the batch.
    * ``fc_same*``   - MLP with 2 outputs over the concatenated features of
      sample ``k`` and sample ``k + len(batch) // 2``.

    The attribute names below must stay as they are: full pickled models loaded
    with ``torch.load`` restore them by name.
    """

    def __init__(self):
        super().__init__()
        # Backbone with its classifier swapped for a pass-through, so it emits the
        # features fc_out consumes (2048 wide, matching fc_out.in_features).
        # NOTE(review): `pretrained=` is deprecated in newer torchvision in favour
        # of `weights=`; kept for compatibility with the torchvision in use.
        self.model = torchvision.models.inception_v3(pretrained=True)
        self.model.fc = I()
        self.model.aux_logits = False
        feat_dim, hidden_dim = 2048, 1024
        self.fc_out = torch.nn.Linear(in_features=feat_dim, out_features=61)
        self.fc_aux_out = torch.nn.Linear(in_features=feat_dim, out_features=45)
        self.fc_same = torch.nn.Linear(in_features=2 * feat_dim, out_features=hidden_dim)
        self.fc_same_hidden = torch.nn.Linear(in_features=hidden_dim, out_features=hidden_dim)
        self.fc_same_out = torch.nn.Linear(in_features=hidden_dim, out_features=2)

    def forward(self, x):
        """Return ``(out, aux_out, same_out)`` for a batch ``x``.

        ``out`` covers the first half of the batch, ``aux_out`` the whole batch,
        and ``same_out`` one row per (first-half, second-half) pair.
        NOTE(review): with an odd batch size the two halves differ in length and
        ``torch.cat`` fails; presumably callers always pass an even batch.
        """
        feats = self.model(x)
        half = len(feats) // 2
        first, second = feats[:half], feats[half:]
        out = self.fc_out(first)
        aux_out = self.fc_aux_out(feats)
        pair = torch.cat([first, second], dim=1)
        hidden = torch.relu(self.fc_same(pair))
        hidden = torch.relu(self.fc_same_hidden(hidden))
        return out, aux_out, self.fc_same_out(hidden)

    def get_feature(self, x):
        """Return the backbone features for ``x`` with no head applied."""
        return self.model(x)

In [None]:
class ImageFilelist(data.Dataset):
    """Map-style dataset over an explicit list of ``(image_path, target)`` pairs.

    Args:
        image_list: sequence of ``(path, target)`` tuples (e.g. the output of
            ``get_images_from_folder``); stored unchanged as ``self.imgs``.
        transform: optional callable applied to each loaded RGB image.
    """

    def __init__(self, image_list, transform=None):
        self.imgs = image_list
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(image, target)`` for ``index``, transformed if a transform is set."""
        impath, target = self.imgs[index]
        img = self.img_loader(impath)
        if self.transform is not None:
            img = self.transform(img)
        return img, target

    def __len__(self):
        return len(self.imgs)

    def img_loader(self, path):
        """Load ``path`` as an RGB PIL image.

        The file is opened explicitly so the ``with`` block closes the handle
        deterministically; ``convert`` runs while it is still open. This mirrors
        torchvision's ``pil_loader`` instead of relying on Pillow's lazy cleanup,
        which can leak descriptors / emit ResourceWarning with many loader workers.
        """
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('RGB')

In [None]:
def imshow(inp, title=None, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Plot a normalized channel-first image tensor with matplotlib.

    Undoes ``Normalize(mean, std)`` (the defaults equal the constants used in
    ``trans``), clips to [0, 1] and shows the result in a 16x8 figure.

    Args:
        inp: 3-D tensor in (C, H, W) order, e.g. a ``make_grid`` output. It may
            require grad or live on the GPU; it is detached and copied to the CPU.
        title: optional figure title.
        mean, std: per-channel statistics that were used to normalize ``inp``.
    """
    img = inp.detach().cpu().numpy().transpose((1, 2, 0))  # CHW -> HWC for imshow
    img = np.asarray(std) * img + np.asarray(mean)
    img = np.clip(img, 0, 1)
    plt.figure(figsize=(16, 8))
    plt.imshow(img)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated

In [None]:
# Run configuration. `...` entries are placeholders that must be filled in
# before the notebook can run.
model_path = ...            # file passed to torch.load when loading the model
data_path = ...             # dataset root scanned by get_images_from_folder
data_type = 'test'          # split name; also the stem of the output .tfrecord file
description = 'right-hand'  # free-text description of the data (not used in the cells shown)

In [None]:
def get_images_from_folder(path, masked=False, confs=None, subdir='right', threshold=0.4):
    """Collect ``(image_path, label)`` pairs from ``<path>/<folder>/<subdir>/*.png``.

    Every entry of ``sorted(os.listdir(path))`` is one class folder; its label is
    its index in that sorted listing, so labels advance even for folders that
    end up contributing no images.

    Args:
        path: dataset root directory.
        masked: if True, keep only images whose confidence in ``confs`` is
            strictly greater than ``threshold``. If none in a folder pass, fall
            back to all of that folder's images that have an entry in ``confs``.
        confs: mapping from ``'<last component of path>/<folder>/<subdir>/<file>'``
            to a confidence score. When None and ``masked`` is True, a global
            named ``confs`` is used (the original behaviour of this function).
        subdir: sub-folder read inside each class folder (default ``'right'``).
        threshold: confidence cut-off used when ``masked`` is True.

    Returns:
        List of ``(path_string, label)`` tuples ordered by folder, then filename.

    Raises:
        ValueError: ``masked`` is True and no confidence mapping is available.
    """
    if masked and confs is None:
        confs = globals().get('confs')
        if confs is None:
            raise ValueError("masked=True requires a `confs` mapping "
                             "(pass confs=... or define a global `confs`)")
    # First component of the confidence key. Strip trailing '/' so that e.g.
    # 'data/' still yields 'data' instead of an empty component.
    root_name = path.rstrip('/').split('/')[-1]
    img_paths = []
    for label, folder in enumerate(sorted(os.listdir(path))):
        full_path = path + '/' + folder + '/' + subdir + '/'
        kept = 0
        fallback = []  # masked mode: images with a confidence entry, used if none pass
        for img in sorted(os.listdir(full_path)):
            if not img.endswith('.png'):
                continue
            entry = (full_path + img, label)
            if not masked:
                img_paths.append(entry)
                kept += 1
                continue
            key = '/'.join([root_name, folder, subdir, img])
            if key in confs:
                fallback.append(entry)
                if confs[key] > threshold:
                    img_paths.append(entry)
                    kept += 1
        if kept == 0:
            img_paths.extend(fallback)
    return img_paths

In [None]:
# Use the first GPU when CUDA is available; otherwise fall back to the CPU so
# the notebook still runs (slowly) on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Images are resized to img_size x img_size and centre-cropped to img_size - 3
# (= 299) in `trans`; presumably 299 is chosen to match Inception-v3's expected
# input size. TODO confirm.
img_size = 302
batch_size = 128

In [None]:
# Preprocessing: square resize, centre crop 3 px smaller, tensor conversion, then
# per-channel normalization. imshow() undoes Normalize with the same constants.
trans = transforms.Compose(
    [
        transforms.Resize(size=(img_size, img_size)),
        transforms.CenterCrop(size=img_size - 3),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [None]:
# (path, label) list -> Dataset -> loader. shuffle=False keeps batches in
# image_list order; the feature-extraction and TFRecord cells index rows by
# that order.
image_list = get_images_from_folder(data_path, masked=False)
data_folder = ImageFilelist(image_list=image_list, transform=trans)
data_loader = torch.utils.data.DataLoader(
    dataset=data_folder,
    batch_size=batch_size,
    shuffle=False,
    num_workers=8,
)

In [None]:
# Persistent iterator over data_loader so the next cell can pull successive
# batches for a visual check. NOTE(review): with num_workers=8 this presumably
# keeps the loader's worker processes alive until `it` is exhausted or deleted.
it = iter(data_loader)

In [None]:
# Visual sanity check: take the next batch from `it`, tile it into a single
# image grid and display it de-normalized via imshow().
inputs, classes = next(it)
out = torchvision.utils.make_grid(tensor=inputs)
imshow(out, title=None)

In [None]:
# Despite its name, `counts` holds the sorted distinct label ids found in
# image_list, not per-label counts (see label_cnts below for those).
counts = np.unique([entry[1] for entry in image_list])

In [None]:
# Number of images in the dataset (ImageFilelist.__len__ is len(self.imgs)).
num_img = len(data_folder)

In [None]:
# Load the full saved model onto `device` and switch it to inference mode.
# SECURITY: torch.load unpickles arbitrary objects; only load trusted files.
# NOTE(review): unpickling needs the I / MultiSignClf classes defined above; on
# torch >= 2.6 the default weights_only=True rejects full-module pickles, so
# confirm the torch version in use.
model = torch.load(model_path, map_location=device).eval().to(device)
# One feature row per image; 2048 matches the backbone width used by MultiSignClf.
features_array = np.zeros((num_img, 2048))

In [None]:
# Fraction of feature entries that are still exactly zero.
# NOTE(review): placed before the extraction loop, this is always 1.0; it only
# becomes informative when re-run after features have been filled in.
(features_array == 0).sum() / features_array.size

In [None]:
# Extract backbone features for every image, in loader order (shuffle=False, so
# row k corresponds to data_folder.imgs[k]). Rows are placed by a running offset
# taken from each batch's real length rather than i * batch_size, so the mapping
# is exact for the short final batch and survives a change of batch size/sampler.
row = 0
with torch.no_grad():  # inference only; no autograd graph needed
    for i, (x, y) in enumerate(data_loader):
        feats = model.get_feature(x.to(device)).cpu().numpy()
        features_array[row:row + len(feats), :] = feats
        row += len(feats)
        if (i + 1) % 100 == 0:
            # fraction of batches completed so far
            print((i + 1) / len(data_loader), end='\r')
# Every row must be written, otherwise features and labels would drift apart.
assert row == len(features_array), (row, len(features_array))

In [None]:
# Label of every image in dataset order, then the number of images per distinct
# label (np.unique sorts labels, so label_cnts is in ascending-label order).
label_list = [entry[1] for entry in data_folder.imgs]
_, label_cnts = np.unique(label_list, return_counts=True)

In [None]:
# Row boundaries per label: images are grouped by label in ascending order, so
# the k-th distinct label owns rows intervals[k]:intervals[k + 1].
intervals = np.concatenate(([0], np.cumsum(label_cnts)))

In [None]:
# Write one sequence record per label: the feature rows of all images with that
# label, stacked in reverse order. NOTE(review): the reason for the [::-1]
# reversal is not visible here; confirm it is intended.
# The `with` block closes the writer even if a write raises.
dis_list = list()  # NOTE(review): unused in this cell
with tf.python_io.TFRecordWriter(data_type + '.tfrecord') as file:
    for start, end in zip(intervals[:-1], intervals[1:]):
        f = features_array[start:end][::-1]
        inpu.write_sequence_record(f, file)