In [1]:
import torch
import numpy as np
import cv2
import config
import torch.nn as nn
from tqdm import tqdm
from model import FaceKeypointResNet50
import time
import utils
from dataset import FacialKeypointsDataset
from torch.utils.data import DataLoader

In [2]:
model = FaceKeypointResNet50(pretrained=False, requires_grad=False).to(config.DEVICE)
# load the model checkpoint
checkpoint = torch.load('../outputs/model.pth')
# load model weights state_dict
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
print('model loaded successfully')

valid_data=FacialKeypointsDataset(f'{config.ROOT_PATH}/test_frames_keypoints.csv',f'{config.ROOT_PATH}/test')
valid_loader=DataLoader(valid_data,batch_size=config.BATCH_SIZE,shuffle=False)
criterion=nn.SmoothL1Loss()


model loaded successfully


In [3]:
def validate(model, dataloader, data):
    print('Validating')
    model.eval()
    valid_running_loss = 0.0
    counter = 0
    num_batches = int(len(data)/dataloader.batch_size)
    total=0.0
    with torch.no_grad():
        for i, data in tqdm(enumerate(dataloader), total=num_batches):
            counter += 1
            image, keypoints = data['image'].to(config.DEVICE), data['keypoints'].to(config.DEVICE)
            # flatten the keypoints
            keypoints = keypoints.view(keypoints.size(0), -1)
            start=time.time()
            outputs = model(image)
            end=time.time()
            total+=end-start
            loss = criterion(outputs, keypoints)
            valid_running_loss += loss.item()
            #if  i == 0:
            #    utils.draw_keypoints(image, outputs, keypoints,epoch)
        
    valid_loss = valid_running_loss/counter
    return valid_loss,total


In [4]:
res=validate(model,valid_loader,valid_data)
print(res)

  0%|                                                                                           | 0/12 [00:00<?, ?it/s]

Validating


13it [00:05,  2.24it/s]                                                                                                

(2.065878492135268, 1.2020952701568604)





In [5]:
def validate_(model, dataloader, data):
    print('Validating')
    model.eval()
    valid_running_loss = 0.0
    counter = 0
    num_batches = int(len(data)/dataloader.batch_size)
    total=0.0
    with torch.no_grad():
        for i, data in tqdm(enumerate(dataloader), total=num_batches):
            counter += 1
            image, keypoints = data['image'].to(config.DEVICE), data['keypoints'].to(config.DEVICE)
            # flatten the keypoints
            keypoints = keypoints.view(keypoints.size(0), -1)
            image=image.to(torch.float16)
            start=time.time()
            outputs = model(image)
            end=time.time()
            total+=end-start
            loss = criterion(outputs, keypoints)
            valid_running_loss += loss.item()
            #if  i == 0:
            #    utils.draw_keypoints(image, outputs, keypoints,epoch)
        
    valid_loss = valid_running_loss/counter
    return valid_loss,total

In [6]:
model_=model.to(torch.float16)
res_=validate_(model_,valid_loader,valid_data)
print(res_)

  0%|                                                                                           | 0/12 [00:00<?, ?it/s]

Validating


13it [00:15,  1.16s/it]                                                                                                

(2.065853870832003, 0.09809064865112305)





In [7]:
from torch.cuda.amp import autocast

ModuleNotFoundError: No module named 'torch.cuda.amp'

In [None]:
def validate_auto(model, dataloader, data):
    print('Validating')
    model.eval()
    valid_running_loss = 0.0
    counter = 0
    num_batches = int(len(data)/dataloader.batch_size)
    total=0.0
    with torch.no_grad():
        for i, data in tqdm(enumerate(dataloader), total=num_batches):
            counter += 1
            image, keypoints = data['image'].to(config.DEVICE), data['keypoints'].to(config.DEVICE)
            # flatten the keypoints
            keypoints = keypoints.view(keypoints.size(0), -1)
            with autocast:    
                start=time.time()
                outputs = model(image)
                end=time.time()
            total+=end-start
            loss = criterion(outputs, keypoints)
            valid_running_loss += loss.item()
            #if  i == 0:
            #    utils.draw_keypoints(image, outputs, keypoints,epoch)
        
    valid_loss = valid_running_loss/counter
    return valid_loss,total

In [None]:
res_auto=validate_auto(model,valid_loader,valid_data)
print(res_auto)