In [2]:
import torch.utils.data as Data
import torch
import torch.optim as optim
from tqdm import tqdm
import torchvision.transforms as T
from torchvision.datasets import LFWPeople
import os
from scipy import ndimage
import numpy as np
from PIL.ImageEnhance import *
from torchvision.transforms import *
import random
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

from models.inception_resnet_v1 import InceptionResnetV1

manualSeed=0
random.seed(manualSeed)
torch.manual_seed(manualSeed)
np.random.seed(manualSeed)
np.random.RandomState(manualSeed)

RandomState(MT19937) at 0x7F64AA865240

In [3]:
lfw_people = LFWPeople(root='/mnt/NAS/home/weicheng/selfLearning/facenet/project/lfw', 
                       download=True, transform=T.Compose([
                T.Resize(size=(160,160)),
                T.RandomHorizontalFlip(0.5),
                T.RandomRotation(degrees=10, interpolation=InterpolationMode.BILINEAR),
                T.ToTensor(),
                T.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
            ]))
# for name in lfw_people.target_names:
#     print(name)

Files already downloaded and verified


In [5]:
resnet = InceptionResnetV1(classify=True, pretrained='casia-webface')
resnet.logits = torch.nn.Linear(512, 5749)
resnet = resnet.cuda(0)
resnet.train()

InceptionResnetV1(
  (conv2d_1a): BasicConv2d(
    (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
    (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
  )
  (conv2d_2a): BasicConv2d(
    (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
  )
  (conv2d_2b): BasicConv2d(
    (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
  )
  (maxpool_3a): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2d_3b): BasicConv2d(
    (conv): Conv2d(64, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(80, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
  )
  (conv2d_4a): 

In [6]:
loader = Data.DataLoader(
    dataset=lfw_people,
    batch_size=32,
    shuffle=True, drop_last=True)

optimizer = optim.Adam(resnet.parameters(), lr=1e-4, betas=(0.9, 0.999), eps=1e-08, weight_decay=1e-5)
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10, 20, 50])
criterion = torch.nn.CrossEntropyLoss()

In [8]:
def random_rot_flip(image):
    new_image = []
    for i in range(image.shape[0]):
        k = np.random.randint(0, 4)
        image_i = np.rot90(image[i], k)
        # label = np.rot90(label, k)
        axis = np.random.randint(0, 2)
        image_i = np.flip(image_i, axis=axis).copy()
        new_image.append(np.array(image_i))
    image = np.array(new_image, dtype=np.float64)
    return image

def random_rotate(image):
    new_image = []
    for i in range(image.shape[0]):
        angle = np.random.randint(-20, 20)
        image_i = ndimage.rotate(image[i], angle, order=0, reshape=False)
        new_image.append(np.array(image_i))
    image = np.array(new_image, dtype=np.float64)
    return image


def random_noise(image):
    new_image = []
    sigma = random.uniform(0.15, 1.15)
    for i in range(image.shape[0]):
        image_i = ToPILImage()(image[i]).filter(ImageFilter.GaussianBlur(radius=sigma))
        new_image.append(np.array(image_i))
    image = np.array(new_image, dtype=np.float64)
    return image

In [7]:
def accuracy(outputs, labels):
    total = 0
    correct = 0
    _, predicted = torch.max(outputs.data, 1) ##
    total += labels.size(0)
    correct += (predicted == labels).sum().item()
    return correct/total

In [8]:
for epoch in range(100):
    list_loss = []
    list_acc = []
    for step, (batch_x, batch_y) in enumerate(tqdm(loader)):
        inputs, labels = batch_x, batch_y
        inputs, labels = inputs.cuda(0), labels.cuda(0)
        
        pred, fea = resnet(inputs)
        pred_softmax = torch.softmax(pred, dim=-1)
        loss = criterion(pred_softmax, labels)
        list_loss.append(loss.item())
        list_acc.append(accuracy(outputs=pred_softmax, labels=labels))
        loss.backward()
        optimizer.step()
        scheduler.step()
    # break
    print('loss in epoch {} is '.format(epoch), sum(list_loss)/len(list_loss))
    print('acc in epoch {} is '.format(epoch), sum(list_acc)/len(list_acc))
    save_mode_path = os.path.join(
                    '../model', 'epoch_' + str(epoch) + '.pth')
    torch.save(resnet.state_dict(), save_mode_path)

100%|██████████| 413/413 [02:21<00:00,  2.92it/s]


loss in epoch 0 is  8.656644470177898
acc in epoch 0 is  0.017857142857142856


100%|██████████| 413/413 [02:26<00:00,  2.81it/s]


loss in epoch 1 is  8.655076932098906
acc in epoch 1 is  0.02754237288135593


 30%|███       | 124/413 [00:43<01:40,  2.86it/s]


KeyboardInterrupt: 