In [1]:
# Import Moudles and Packages
import os
import time
import random

import warnings
warnings.filterwarnings("ignore")

import numpy as np
import matplotlib.pyplot as plt

%matplotlib inline
plt.rcParams["axes.grid"]=False

# Import pytorch modules
import torch
import torch.nn as nn
from torch.utils import data as D

from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.cuda.amp import autocast, GradScaler
from madgrad import MADGRAD

device = 'cuda:2' if torch.cuda.is_available() else 'cpu'

print("| Pytorch version: {}".format(torch.__version__))
print("| GPU: {}".format(torch.cuda.is_available()))
print("| Device : ",device)
print("| Device name: ", torch.cuda.get_device_name(0))
print("| Device count: ", torch.cuda.device_count())


| Pytorch version: 1.10.0+cu102
| GPU: True
| Device :  cuda:2
| Device name:  Tesla T4
| Device count:  4


In [2]:
# Import local modules
from src import config as C
from src.models import hrnet, resnet, basenet

from src.utils.collate_fn import *
from src.utils.print_overwrite import *
from src.utils.seed import *
from src.utils.view_kypoints import *

from src.dataset.kface_dataset import *
from src.dataset.album_transform import *
from src.dataset import w300_dataset as W

seed_everything(C.SEED)

print(f"| Number of image : {len(C.IMAGE_LIST)}")
print(f"| Number of label : {len(C.LABEL_LIST)}")
print(f"| Number of trainset : {C.LEN_TRAIN_SET}")
print(f"| Number of validset : {C.LEN_VALID_SET}")

| Number of image : 100
| Number of label : 100
| Number of trainset : 90
| Number of validset : 10


In [3]:
w_dataset = W.FaceLandmarksDataset(W.Transforms())

# # (Option) visualize dataset
# for idx, (image, landmarks) in enumerate(w_dataset): # torch.Size([1, 512, 512])
#     if idx == 1:
#         vis_keypoints(image, landmarks)

In [4]:
len_valid_set = int(0.1*len(w_dataset))
len_train_set = len(w_dataset) - len_valid_set

print("The length of Train set is {}".format(len_train_set))
print("The length of Valid set is {}".format(len_valid_set))

train_dataset , valid_dataset,  = torch.utils.data.random_split(w_dataset , [len_train_set, len_valid_set])

# shuffle and batch the datasets
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=4)
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=2, shuffle=False, num_workers=4)

train_images, train_landmarks = next(iter(train_loader))
valid_images, valid_landmarks = next(iter(valid_loader))

print(f"| Size of image in train_loader : {train_images.shape}")
print(f"| Size of label in train_loader : {train_landmarks.shape}")
print(f"| Size of image in train_loader : {valid_images.shape}")
print(f"| Size of label in train_loader : {valid_landmarks.shape}")

The length of Train set is 6000
The length of Valid set is 666
| Size of image in train_loader : torch.Size([16, 1, 512, 512])
| Size of label in train_loader : torch.Size([16, 27, 2])
| Size of image in train_loader : torch.Size([2, 1, 512, 512])
| Size of label in train_loader : torch.Size([2, 27, 2])


In [5]:
from torchvision import models

class resnext50_32x4d(nn.Module):
    def __init__(self,num_classes=54):
        super().__init__()
        self.model_name='resnext50_32x4d'
        self.model=models.resnext50_32x4d(pretrained=True)
    
        self.model.conv1=nn.Conv2d(1, 64, kernel_size=3, stride=2, padding=3, bias=False)
        self.model.fc=nn.Linear(self.model.fc.in_features, num_classes)
        
    def forward(self, x):
        x = self.model(x)
        
        return x

In [6]:
model = resnext50_32x4d()
# model = model.cpu()
# model.load_state_dict(torch.load("/home/yuyuqw92/FLD-scratch/result/face_landmarks.pth",map_location='cpu'))
model = model.to(device)

In [7]:
x = torch.randn([1, 1, 512, 512]).to(device)
out = model(x).to(device)
print(f"input : {x.shape} | output : {out.size()}")

input : torch.Size([1, 1, 512, 512]) | output : torch.Size([1, 54])


In [8]:
criterion = nn.MSELoss()
optimizer = MADGRAD(params=model.parameters(), lr=C.LEARNING_RATE, weight_decay=C.WEIGHT_DECAY)
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=C.EPOCHS, T_mult=1)

In [9]:
from tqdm import tqdm

loss_min = np.inf

start_time = time.time()
for epoch in range(C.EPOCHS):
    
    loss_train = 0
    loss_valid = 0
    running_loss = 0
    
    model.train()
    for step, (images, landmarks) in enumerate(train_loader):
        model = model.to(device)    

        # images, landmarks = next(iter(train_loader))

        images = images.to(device)
        landmarks = landmarks.view([landmarks.size(0),-1]).to(device)
        
        predictions = model(images)
        
        # clear all the gradients before calculating them
        optimizer.zero_grad()
        
        # find the loss for the current step
        loss_train_step = criterion(predictions, landmarks)
        
        # calculate the gradients
        loss_train_step.backward()
        
        # update the parameters
        optimizer.step()
        
        loss_train += loss_train_step.item()
        running_loss = loss_train/(step+1)
        
        print_overwrite(step, len(train_loader), running_loss, 'train')
        
    model.eval() 
    with torch.no_grad():
        
        for step in range(1,len(valid_loader)+1):
            
            images, landmarks = next(iter(valid_loader))
      
            images = images.to(device)
            landmarks = landmarks.view(landmarks.size(0),-1).to(device)
        
            predictions = model(images).to(device)
                    
            # find the loss for the current step
            loss_valid_step = criterion(predictions, landmarks)

            loss_valid += loss_valid_step.item()
            running_loss = loss_valid/(step+1)

            print_overwrite(step, len(valid_loader), running_loss, 'valid')

    loss_train /= len(train_loader)
    loss_valid /= len(valid_loader)
    
    print('\n--------------------------------------------------')
    print('Epoch: {}  Train Loss: {:.4f}  Valid Loss: {:.4f}'.format(epoch, loss_train, loss_valid))
    print('\n--------------------------------------------------')
    
    if loss_valid < loss_min:
        loss_min = loss_valid
        torch.save(model.state_dict(), '/home/ubuntu/workspace/FLD-scratch/result/face_landmarks.pth') 
        print("\nMinimum Validation Loss of {:.4f} at epoch {}/{}".format(loss_min, epoch, C.EPOCHS))
        print('Model Saved\n')

print('Training Complete')
print("Total Elapsed Time : {} s".format(time.time()-start_time))

Valid Steps: 333/333  Loss: 0.0022 
--------------------------------------------------
Epoch: 0  Train Loss: 0.0085  Valid Loss: 0.0023

--------------------------------------------------

Minimum Validation Loss of 0.0023 at epoch 0/100
Model Saved

Valid Steps: 333/333  Loss: 0.0021 
--------------------------------------------------
Epoch: 1  Train Loss: 0.0033  Valid Loss: 0.0021

--------------------------------------------------

Minimum Validation Loss of 0.0021 at epoch 1/100
Model Saved

Valid Steps: 96/333  Loss: 0.0012  

In [None]:
start_time = time.time()

with torch.no_grad():

    best_network = resnext50_32x4d()
    best_network.to(device)
    best_network.load_state_dict(torch.load('/home/ubuntu/workspace/FLD-scratch/result/face_landmarks.pth')) 
    best_network.eval()
    
    images, landmarks = next(iter(valid_loader))
    
    images = images.to(device)
    landmarks *= 512 #* torch.tensor((512,512))    

    predictions = best_network(images).cpu()
    predictions = predictions.view([-1,27,2])
    predictions *= 512# * torch.tensor((512,512))    
    
    plt.figure(figsize=(10,20))
    
    for img_num in range(2):
        print(images[img_num].shape)
        plt.subplot(2,1,img_num+1)
        plt.imshow(images[img_num].cpu().permute(1,2,0).squeeze(), cmap='gray')
        plt.scatter(predictions[img_num].T[0], predictions[img_num].T[1], c = 'r', s = 5)
        plt.scatter(landmarks[img_num].T[0], landmarks[img_num].T[1], c = 'g', s = 5)

print('Total number of test images: {}'.format(len(valid_dataset)))

end_time = time.time()
print("Elapsed Time : {}".format(end_time - start_time)) 