In [3]:
# Reload the autoreload IPython extension (it is loaded if it is not loaded yet).
%reload_ext autoreload
# Mode 2: reload all modules (except those excluded via %aimport) before executing each code cell
%autoreload 2 
import cv2
import numpy as np
import os
import pathlib
from fastai.vision import *
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from matplotlib.transforms import Bbox
from matplotlib.backends.backend_agg import FigureCanvasAgg
from PIL import Image as pilImg
%matplotlib inline

In [4]:
import warnings
warnings.filterwarnings("ignore", category=UserWarning, module="torch.nn.functional")

# Helper function to write images to disk

In [5]:
def imwrite(imageName, img,
            directory='/Users/sandeep/Desktop/FinalYearProject/YoutubeScreenshotScrapper/data/test'):
    '''
    Save `img` as `<imageName>.png` inside `directory`.

    Parameters
    ----------
    imageName : str or int
        Base file name, without the `.png` extension.
    img : numpy.ndarray
        Image handed straight to cv2.imwrite (OpenCV expects BGR channel order
        for colour images).
    directory : str or path-like, optional
        Destination folder. Defaults to the folder that used to be hard-coded here.

    Raises
    ------
    OSError
        If OpenCV reports that the file could not be written. cv2.imwrite signals
        failure (e.g. a non-existent folder) by returning False rather than raising,
        so without this check a failed write would go unnoticed.
    '''
    fileName = os.path.join(directory, f'{imageName}.png')
    if not cv2.imwrite(fileName, img):
        raise OSError(f'cv2.imwrite failed to write {fileName}')

# Loading the court-segmentation model and its class codes

In [6]:
# Folder holding the court-segmentation artefacts; codes.txt is read from here.
# Set the COURT_SEG_DIR environment variable to use another location instead of
# editing this machine-specific absolute path.
path = Path(os.environ.get('COURT_SEG_DIR',
                           '/Users/sandeep/Desktop/FinalYearProject/Python Files/courtSegmentation'))


In [7]:
# Class names listed in codes.txt; a name's position in this array is its label id (see name2id).
codes = np.loadtxt(fname=path/'codes.txt', dtype=str)
codes

array(['Other', '2Point', '3Point', 'Board', 'Freethrow', 'Layup'], dtype='<U9')

In [8]:
# Map each class name to its index in `codes` (presumably the label id used in the masks).
name2id = dict(zip(codes, range(len(codes))))
# Label id of the 'Other' class; acc_courtSeg ignores pixels with this label.
void_code = name2id['Other']

def acc_courtSeg(input, target):
    '''
    Pixel-accuracy metric for the court-segmentation learner, ignoring 'Other'.

    `input` holds per-class scores along dim 1 (the prediction is its arg-max there).
    `target` is squeezed on dim 1 (presumably a singleton channel axis — confirm).
    Pixels whose target equals `void_code` are excluded. The result is the fraction
    of the remaining pixels whose predicted class matches the target. If every pixel
    is void, the mean is taken over an empty tensor and is NaN.
    '''
    labels = target.squeeze(1)
    keep = labels != void_code
    predicted = input.argmax(dim=1)
    return (predicted[keep] == labels[keep]).float().mean()

In [9]:
# Exported fastai learner for court segmentation (the folder is named 'u-net').
# load_learner restores the learner from a pickle, so only point it at a trusted export.
# The metric acc_courtSeg is presumably referenced by the export and must be defined first — confirm.
# Set COURT_SEG_MODEL_DIR to load from another folder instead of editing this absolute path.
learn = load_learner(os.environ.get('COURT_SEG_MODEL_DIR',
                                    '/Users/sandeep/Desktop/dataandmodles/models/courtSegmentation/u-net'))



# Converting the U-Net output to a usable format (a BGR image array that OpenCV can save)

In [7]:
# Settings used by set_figsize: `dpi` converts pixel counts to inches and
# `scalefactor` multiplies the resulting figure size.
fig_args = {"dpi": 72,
            "scalefactor": 1}

# Off-screen figure used only to rasterise masks through a FigureImage.
# It is built with matplotlib.figure.Figure rather than plt.figure() so that it is not
# registered with pyplot: the inline backend then no longer prints an empty
# "<Figure ... with 0 Axes>" after this cell, and the figure does not linger in pyplot.
# Its dpi is tied to fig_args["dpi"] so the pixel->inch conversion in set_figsize
# produces a figure (and hence a clip box) of exactly the mask's pixel size.
from matplotlib.figure import Figure
fig = Figure(dpi=fig_args["dpi"])
figCanvasAgg = FigureCanvasAgg(fig)
renderer = figCanvasAgg.get_renderer()
figureImage = mpimg.FigureImage(fig)
def set_figsize(w,h):
    '''
    Helper for apply_cmap: resize the shared figure to cover a w x h pixel image.

    Pixel counts are multiplied by fig_args["scalefactor"] and divided by
    fig_args["dpi"] to get inches.
    Note: for an image array of shape (H, W), pass w = shape[1] and h = shape[0].
    '''
    scale = fig_args["scalefactor"]
    dpi = fig_args["dpi"]
    width_in = w * scale / dpi
    height_in = h * scale / dpi
    fig.set_size_inches(width_in, height_in)

def rbga_to_brg(img):
    '''
    Helper function for apply_cmap: convert a rendered RGBA image to BGR.

    Given: the tuple returned by FigureImage.make_image, whose first element is
           an RGBA image array of shape (H, W, 4).
    Returns: a C-contiguous (H, W, 3) array in BGR channel order (OpenCV's order),
             with the input's dtype. The alpha channel is dropped.
    Raises: ValueError if the tuple carries no image (first element is None).
    '''
    rgba = img[0]
    if rgba is None:
        raise ValueError('make_image returned no image data (e.g. an empty or fully clipped mask)')
    # Channels 2, 1, 0 = B, G, R. No np.squeeze here: it would also drop the H or W
    # axis of a one-pixel-high or one-pixel-wide image. The copy gives contiguous memory.
    return np.ascontiguousarray(rgba[:, :, 2::-1])

def rotate_img(img):
    '''
    Helper for apply_cmap: turn the image by 180 degrees in the plane of its
    first two axes (the direction does not matter for a half turn).
    Returns a view of the input.
    '''
    half_turn = 2
    return np.rot90(img, half_turn, axes=(0, 1))
def reflect_lr(img):
    '''
    Helper for apply_cmap: mirror the image left-to-right, i.e. reverse the
    column axis (axis 1). Returns a view of the input.
    '''
    column_axis = 1
    return np.flip(img, axis=column_axis)
    
def apply_cmap(img_mask):
    '''
    Render a class mask through the shared FigureImage and return it as a BGR array.

    The figure is resized to the mask's pixel size and the mask is colour-mapped by
    the FigureImage. The RGBA rendering is then converted to BGR. Finally it is
    rotated 180 degrees and mirrored left-right, which together flip the rows
    upside down (presumably to undo the row order of the rendering — confirm).

    img_mask: presumably a 2-D (H, W) array of class ids from get_class_mask.
    '''
    height, width = img_mask.shape[0], img_mask.shape[1]
    set_figsize(width, height)
    figureImage.set_data(img_mask)
    rendered = figureImage.make_image(renderer)
    bgr = rbga_to_brg(rendered)
    # rotate_img followed by reflect_lr is equivalent to a vertical flip.
    return reflect_lr(rotate_img(bgr))
  
def get_class_mask(x):
    '''
    Extract the predicted class mask from a learn.predict result as a NumPy array.

    `x` is presumably the tuple returned by fastai's learn.predict, whose first
    element is the predicted segment exposing a `.data` tensor — confirm.
    '''
    predicted_segment = x[0]
    return image2np(predicted_segment.data)
    
    

<Figure size 432x288 with 0 Axes>

# Function that puts everything together: segment the court in one BGR frame

In [None]:
def getCourtMask(img):
    '''
    Segment the court in a single OpenCV frame and return the rendered mask.

    img: image array in OpenCV's BGR channel order (e.g. a cv2.VideoCapture frame);
         presumably uint8, since it is divided by 255 — confirm.
    Returns: the colour-mapped class mask from apply_cmap, a BGR array that can be
             saved with cv2.imwrite.
    '''
    rgb = cv2.cvtColor(img , cv2.COLOR_BGR2RGB)
    # Float tensor scaled by 1/255 (so [0, 1] for uint8 input), wrapped in fastai's Image.
    scaled = pil2tensor(rgb, dtype=np.float32).div_(255)
    prediction = learn.predict(Image(scaled))
    return apply_cmap(get_class_mask(prediction))

# Demo on a video (currently disabled: the code in the cell below is commented out)

In [8]:
# videoPath = '/Users/sandeep/Desktop/dataandmodles/data/test2.mov'
# savePath = '/Users/sandeep/Desktop/dataandmodles/u-netDemoImages/lukaDrive'

# cap  = cv2.VideoCapture(videoPath)
# count = 0
# while(cap.isOpened()):
#     ret, img = cap.read()
#     if ret:
#         #converting the image to RGB image before sending it off to model
#         img = cv2.cvtColor(img , cv2.COLOR_BGR2RGB)
#         # converting the image to fastai image object
#         img = Image(pil2tensor(img, dtype=np.float32).div_(255))
#         x = learn.predict(img)
#         img_mask = get_class_mask(x)
#         img_brg = apply_cmap(img_mask)
#         cv2.imwrite(f'{savePath}/{count}.png' ,img_brg)
#         if cv2.waitKey(1) & 0xFF == ord('q'):
#             break
#     else:
#         cap.release()
#         break
#     count += 1 
# cap.release()
# cv2.destroyAllWindows()   

KeyboardInterrupt: 