In [1]:
# import os

# os.system("pip install pretrainedmodels")

In [2]:
import torch
import numpy as np
import cv2
import config
import torch.nn as nn
from tqdm import tqdm
from model import FaceKeypointResNet50
import time
import utils
from dataset import FacialKeypointsDataset
from torch.utils.data import DataLoader

import torch.nn.utils.prune as prune

In [3]:
model = FaceKeypointResNet50(pretrained=False, requires_grad=False).to(config.DEVICE)
# load the model checkpoint
checkpoint = torch.load('../outputs/model.pth')
# load model weights state_dict
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
print('model loaded successfully')
# Number of images is 2308

valid_data=FacialKeypointsDataset(f'{config.ROOT_PATH}/test_frames_keypoints.csv',f'{config.ROOT_PATH}/test')
valid_loader=DataLoader(valid_data,batch_size=1,shuffle=False)
criterion=nn.SmoothL1Loss()


model loaded successfully


In [4]:
def validate(model, dataloader, data, half_precision=False):
    print('Validating')
    model.eval()
    valid_running_loss = 0.0
    counter = 0
    num_batches = int(len(data)/dataloader.batch_size)
    total=0.0
    with torch.no_grad():
        for i, data in tqdm(enumerate(dataloader), total=num_batches):
            counter += 1
            image, keypoints = data['image'].to(config.DEVICE), data['keypoints'].to(config.DEVICE)
            # flatten the keypoints
            keypoints = keypoints.view(keypoints.size(0), -1)
            if half_precision:
                image=image.to(torch.float16)
            start=time.time()
            outputs = model(image)
            end=time.time()
            total+=end-start
            loss = criterion(outputs, keypoints)
            valid_running_loss += loss.item()
            #if  i == 0:
            #    utils.draw_keypoints(image, outputs, keypoints,epoch)
    valid_loss = valid_running_loss/counter
    return valid_loss,total, counter


In [5]:
# Original Model
valid_loss, total_time, counter = validate(model,valid_loader,valid_data)
print("Validation loss is: " + str(valid_loss))
print("Total time is: " + str(total_time))
print("Fps is: " + str(2308 / total_time))

  0%|                                                                                          | 0/770 [00:00<?, ?it/s]

Validating


100%|████████████████████████████████████████████████████████████████████████████████| 770/770 [00:10<00:00, 71.47it/s]

Validation loss is: 2.0754783021939267
Total time is: 6.416483640670776
Fps is: 359.69857156196576





In [6]:
# Half precision
model_=model.to(torch.float16)
valid_loss, total_time, counter = validate(model_,valid_loader,valid_data, half_precision=True)
print("Validation loss is: " + str(valid_loss))
print("Total time is: " + str(total_time) + "s")
print("Fps is: " + str(2308 / total_time))

  0%|▏                                                                                 | 2/770 [00:00<00:40, 19.03it/s]

Validating


100%|████████████████████████████████████████████████████████████████████████████████| 770/770 [00:40<00:00, 18.84it/s]

Validation loss is: 2.075412207919282
Total time is: 6.090193748474121s
Fps is: 378.9698809792812





In [7]:
model_prune = FaceKeypointResNet50(pretrained=False, requires_grad=False).to(config.DEVICE)
checkpoint_prune = torch.load('../outputs/model_prune_0.2_new.pth')
model_prune.load_state_dict(checkpoint['model_state_dict'])
model_prune.eval()
print('model loaded successfully')

model loaded successfully




In [8]:
# Pruned model
valid_loss, total_time, counter = validate(model_prune,valid_loader,valid_data)
print("Validation loss is: " + str(valid_loss))
print("Total time is: " + str(total_time) + "s")
print("Fps is: " + str(2308 / total_time))

  1%|▊                                                                                 | 8/770 [00:00<00:09, 76.85it/s]

Validating


100%|████████████████████████████████████████████████████████████████████████████████| 770/770 [00:09<00:00, 78.96it/s]

Validation loss is: 2.0754783021939267
Total time is: 5.616914987564087s
Fps is: 410.9017147508798





In [9]:
print(counter)

770


In [10]:
# state_dict = torch.load('../outputs/model_prune_0.2.pth', map_location="cpu")
# torch.save(state_dict, '../outputs/model_prune_0.2_new.pth', _use_new_zipfile_serialization=False)

In [11]:
# # Remove the redundant weights
# model = FaceKeypointResNet50(pretrained=False, requires_grad=False, pruning_amount=0.2)
# checkpoint = torch.load('../outputs/model_prune_0.2_new.pth')
# # load model weights state_dict
# model.load_state_dict(checkpoint['model_state_dict'])


# # Iterate over the convolutional layers and remove pruning reparametrization
# for module in model.model.modules():
#     if isinstance(module, nn.Conv2d):
#         prune.remove(module, 'weight')

# # Save the updated model
# torch.save(model.state_dict(), '../outputs/model_prune_0.2_new_removed.pth', _use_new_zipfile_serialization=False)


In [12]:
# # Save the original model into only the model to see the differences
# model = FaceKeypointResNet50(pretrained=False, requires_grad=False)
# checkpoint = torch.load('../outputs/model.pth')
# # load model weights state_dict
# model.load_state_dict(checkpoint['model_state_dict'])

# # Save the updated model
# torch.save(model.state_dict(), '../outputs/model_new.pth', _use_new_zipfile_serialization=False)