In [26]:
from torchvision import models
from torchvision import transforms
import numpy as np  
import torch
import torch.nn
import torchvision
from torch.autograd import Variable
from torchvision import transforms
import PIL
#import the PIL library for image handling
from PIL import Image
from torchvision import io
import os
import av


#initialize the device for GPU/CPU support
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

#initialize the transforms - these are like pre steps to do on the image once we get it.
data_transforms = transforms.Compose(
    [
        transforms.Resize(196),
        transforms.CenterCrop(196),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.456, 0.406],std=[0.229, 0.224, 0.225])
    ]
)

#load the network we want to try(use a pre trained network)
alexnet = models.alexnet(pretrained=True)

def preprocessImageTo3DTensor(inputImage):
    image = inputImage
    if (torch.is_tensor(inputImage)):
        image = PIL.Image.fromarray(image) #Webcam frames are numpy array format                          
    image = data_transforms(image)
    image = image.float()
    image = image.unsqueeze(0) #I don't know for sure but Resnet-50 model seems to only
                               #accepts 4-D Vector Tensor so we need to squeeze another
    return image  

def displayImage(imageTensor):
    to_pil = transforms.ToPILImage()
    #remove one dimension as display only takes 3 dimensional tensor
    img_pil = to_pil(imageTensor.squeeze(0))
    img_pil.show()
    return

def evaluateTensor(inputTensor):
    #Alexnet model seems to only accpets 4-D Vector Tensor so we need to squeeze
    #batch_t = inputTensor.unsqueeze(0)
    alexnet.eval()
    return alexnet(inputTensor)
   
def processSingleImage(imgRaw):
    imgTensor = preprocessImageTo3DTensor(imgRaw)

    #displayImage(imgTensor)
    out = evaluateTensor(imgTensor)

    #read the data labels so we can take the out from above and identify the labels for the output
    with open("imagenet_classes.txt") as f:
      labels = [line.strip() for line in f.readlines()]

    #print out the label and the percentage match
    _, index = torch.max(out, 1)
    percentage = torch.nn.functional.softmax(out, dim=1)[0] * 100
    return(labels[index[0]],percentage[index[0]].item())

def processImageFromFile():
    # Reading a raw image from a file.
    labels, score = processSingleImage(Image.open("rat.jpg"))
    print("Labels: ", labels, "Score: ", score)

def processVideoFromFile():
    videoContainer = av.open("elephant.mp4")
    frameCount = 0
    totalPercentDict = {}
    for frame in videoContainer.decode(video=0):
        imgRaw = frame.to_image()
        labels, score = processSingleImage(imgRaw)
        #if(score > 35.0):
         #   imgRaw.show()
        
        #print("Frames: ", frameCount, " Labels: ", labels, " Score: ", score)
        if(labels in totalPercentDict):
            prePercent = totalPercentDict.get(labels)
            totalPercentDict[labels] = prePercent + score
            
        else: totalPercentDict[labels] = score
        frameCount += 1
    #print(totalPercentDict)
    denominator = 0
    maxPercent = 0
    maxPercentLabel = ""
    for label, percent in totalPercentDict.items():
        denominator += percent
    for label, percent in totalPercentDict.items():
        percentFrac = percent/denominator
        if(percentFrac > maxPercent):
            maxPercent = percentFrac
            maxPercentLabel = label
    print("Label: ", maxPercentLabel, " Percent: ", 100*maxPercent)
        



    
    
processVideoFromFile()


Label:  101: 'tusker',  Percent:  88.94282565992087
