In [1]:
import torch
import torchvision
import os
import pandas as pd
import numpy as np
from skimage import io
import matplotlib.pyplot as plt
import tqdm
from torch import nn
import torch.nn.functional as F
from skimage import transform
from torch.autograd import Variable
import logging
from collections import Counter
import matplotlib.pyplot as plt
%matplotlib inline
from sklearn.model_selection import train_test_split

In [2]:
with open('../eval/identity_CelebA.txt') as file:
    labels = []
    fnames = []
    for line in file:
        fields = line.strip().split()
        fnames.append(fields[0])
        labels.append(int(fields[1]))
    labels = np.array(labels)
    fnames = np.array(fnames)
f_ids = dict(zip(fnames, range(len(fnames))))

In [3]:
class TripleBatchGen:
    def __init__(self, X, y, batch_size):
        self.X = X
        self.y = y
        self.batch_size = batch_size
        self.yvals = np.unique(y)
        counts = Counter(y)
        self.ybig = np.array([value for value, count in counts.items() if count > 2 * batch_size])
        self.yids = dict()
        for yval in self.yvals:
            self.yids[yval] = np.where(y == yval)[0]
        self.y_train, self.y_test = train_test_split(self.ybig, test_size=0.01, random_state=1234)
        yset = set(self.y_train)
        self.train_ids = np.array([idx for idx in range(len(self.y)) if self.y[idx] in yset])
            
    def generate_batches(self, n_batches):
        for _ in range(n_batches):
            an_label = np.random.choice(self.y_train)
            np.random.shuffle(self.yids[an_label])
            an_ids = self.yids[an_label][:self.batch_size]
            pos_ids = self.yids[an_label][self.batch_size: 2 * self.batch_size]
            neg_ids = self.train_ids[np.random.randint(0, len(self.train_ids), self.batch_size)]
            while np.any(self.y[neg_ids] == an_label):
                neg_ids = np.random.randint(0, len(self.y), self.batch_size)
            yield self.X[an_ids], self.X[pos_ids], self.X[neg_ids]
            
    def test_batch(self):
        an_label = np.random.choice(self.y_test)
        np.random.shuffle(self.yids[an_label])
        an_ids = self.yids[an_label][:self.batch_size]
        pos_ids = self.yids[an_label][self.batch_size: 2 * self.batch_size]
        neg_ids = self.train_ids[np.random.randint(0, len(self.train_ids), self.batch_size)]
        return self.X[an_ids], self.X[pos_ids], self.X[neg_ids]

In [4]:
batch_gen = TripleBatchGen(fnames, labels, 10)

In [5]:
def load_ims(fnames, addr='../img_align_celeba/'):
    ims = []
    for fname in fnames:
        ims.append(io.imread(addr + fname))
    return np.stack(ims)

In [6]:
def ran_transform(X):
    height = X.shape[1] 
    width = X.shape[2]
    X_out = np.zeros((X.shape[0], height, width, 3), dtype=np.float32)
    
    for idx, x in enumerate(X):
        x = transform.resize(x, (height, width))
        angle = np.random.uniform(-10, 10)
        h_scale = np.random.randint(0, 10)
        v_scale = np.random.randint(0, 10)
        h_pos = np.random.randint(0, h_scale + 1)
        v_pos = np.random.randint(0, v_scale + 1)
        x = transform.rotate(x, angle)
        x = transform.resize(x, (height + v_scale, width + h_scale))
        x = x[v_pos:height + v_pos, h_pos: width + h_pos]
        X_out[idx] = x
    X_out = X_out.transpose([0, 3, 1, 2])
    return X_out

def test_transform(X):
    height = X.shape[1]
    width = X.shape[2] 
    X_out = np.zeros((X.shape[0], height, width, 3), dtype=np.float32)
    for idx, x in enumerate(X):
        x = transform.resize(x, (height, width))
        X_out[idx] = x
    X_out = X_out.transpose([0, 3, 1, 2])
    return X_out

In [11]:
class MyVGG(nn.Module):
    
    def __init__(self):
        super(MyVGG, self).__init__()
        self.features = torchvision.models.squeezenet1_0(pretrained=False).features
        #for par in self.features.parameters():
        #    par.requires_grad=False
        #15360
        #66560
        self.embeddings = nn.Sequential(nn.Linear(66560, 1024), nn.ReLU(inplace=True), nn.Dropout(0.5),
                                        #nn.Linear(2048, 1024), nn.ReLU(inplace=True), nn.Dropout(0.5),
                                        nn.Linear(1024, 128))
        self.classif = nn.Sequential(self.embeddings, nn.ReLU(), nn.Linear(128, 40))
        
    def forward(self, x):
        x = self.features(x)
        x = x.view(x.shape[0], -1)
        # print(x.shape)
        return F.normalize(self.embeddings(x), dim=1)
    

In [12]:
class Logger:
    def __init__(self, filename):
        self.filename = filename
        self.logstack = []
        
    def log(self, text):
        self.logstack.append(text)
        
    def flush(self):

        with open(self.filename, 'a') as f:
            for text in self.logstack:
                f.write(text + '\n')
        self.logstack = []
        

In [13]:
torch.cuda.set_device(3)

In [14]:
# net = MyVGG().cuda()
# optimizer = torch.optim.Adam(net.parameters(), lr=0.0001)
# criterion = F.triplet_margin_loss

In [15]:
epoch = 0

In [16]:
net = MyVGG().cuda()
net.load_state_dict(torch.load('model_tr'))
optimizer = torch.optim.Adam(net.parameters(), lr=0.0001)
criterion = F.triplet_margin_loss

In [17]:
batch_size = batch_gen.batch_size
epoch_size = 50
n_epochs = 3000
logger = Logger('logs2.txt')
curloss = 100000
for n in range(n_epochs):
    net.train()
    for b_id, (Xanc, Xpos, Xneg) in enumerate(batch_gen.generate_batches(epoch_size)):
        x = np.concatenate((Xanc, Xpos, Xneg))
        x = ran_transform(load_ims(x))
        x = Variable(torch.FloatTensor(x).cuda())
        optimizer.zero_grad()
        out = net(x)
        loss = criterion(out[:batch_size], out[batch_size: 2 * batch_size], out[2 * batch_size:])
        loss.backward()
        optimizer.step()
        lossv = loss.data.cpu().numpy()[0]
        logger.log(' '.join(['train', str(epoch), str(b_id), str(lossv)]))
    net.eval()
    Xanc, Xpos, Xneg = batch_gen.test_batch()
    x = np.concatenate((Xanc, Xpos, Xneg))
    x = test_transform(load_ims(x))
    x = Variable(torch.FloatTensor(x).cuda())
    out = net(x)
    loss = criterion(out[:batch_size], out[batch_size: 2 * batch_size], out[2 * batch_size:])
    lossv = loss.data.cpu().numpy()[0]
    logger.log(' '.join(['test', str(epoch), str(lossv)]))
    if lossv < curloss:
        curloss = lossv
        torch.save(net.state_dict(), 'model_tr')
    epoch += 1
    logger.flush()
        

  warn("The default mode, 'constant', will be changed to 'reflect' in "
ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



Traceback (most recent call last):
  File "/usr/local/lib/python3.5/dist-packages/IPython/core/interactiveshell.py", line 2963, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-17-c27a348c4e97>", line 10, in <module>
    x = ran_transform(load_ims(x))
  File "<ipython-input-6-c4f6d6942448>", line 7, in ran_transform
    x = transform.resize(x, (height, width))
  File "/usr/local/lib/python3.5/dist-packages/skimage/transform/_warps.py", line 135, in resize
    preserve_range=preserve_range)
  File "/usr/local/lib/python3.5/dist-packages/skimage/transform/_warps.py", line 775, in warp
    order=order, mode=mode, cval=cval))
  File "skimage/transform/_warps_cy.pyx", line 131, in skimage.transform._warps_cy._warp_fast
  File "/usr/local/lib/python3.5/dist-packages/numpy/core/numeric.py", line 424, in asarray
    def asarray(a, dtype=None, order=None):
KeyboardInterrupt

During handling of the above exception, another exception occurred:

Traceback (m

KeyboardInterrupt: 

In [18]:
net = net.cpu()

In [19]:
torch.save(net.state_dict(), 'model_tr1')