In [None]:
import numpy as np
import torch
import torch.nn as nn
from torchvision import datasets, transforms, models
from torchvision import models
from torch.utils.data import DataLoader

import cv2 as cv
from PIL import Image
from tqdm import tqdm
import glob
import matplotlib.pyplot as plt

In [None]:
# DeepLabV3 with a MobileNetV3-Large backbone and a two-class output head.
# No pretrained segmentation weights are requested; the download progress bar
# stays enabled.
deeplab = models.segmentation.deeplabv3_mobilenet_v3_large(
    pretrained=False,
    progress=True,
    num_classes=2,
)

In [None]:
class FootballModel(nn.Module):
    """Wrapper that returns only the ``'out'`` entry of a segmentation network.

    Args:
        backbone: Optional module whose forward pass returns a mapping with an
            ``'out'`` entry. When omitted, the notebook-level ``deeplab``
            instance is wrapped, which is the original behaviour.
    """

    def __init__(self, backbone=None):
        super().__init__()
        # NOTE: the default wraps the module-level ``deeplab`` object itself,
        # so every default-constructed FootballModel shares (and mutates) the
        # same underlying network. Pass ``backbone`` to get an independent one.
        self.dl = deeplab if backbone is None else backbone

    def forward(self, x):
        """Run the wrapped network on ``x`` and return its ``'out'`` output."""
        return self.dl(x)['out']

In [None]:
class SegDataset:
    """Map-style dataset that pairs images with binary segmentation masks.

    Images are listed from ``parentDir/imageDir`` and masks from
    ``parentDir/maskDir``. Item ``i`` pairs the i-th image with the i-th mask
    after both listings are sorted by path, so corresponding files must sort
    identically in the two directories.

    Args:
        parentDir: Directory that contains both sub-directories.
        imageDir: Name of the sub-directory holding the input images.
        maskDir: Name of the sub-directory holding the masks.
        size: ``(height, width)`` every image and mask is resized to. The
            default ``(380, 640)`` matches the preprocessing applied at
            inference time in ``segment_hands``.
    """

    def __init__(self, parentDir, imageDir, maskDir, size=(380, 640)):
        # glob returns paths in arbitrary order; sort both listings so that
        # image/mask pairing by index is deterministic.
        self.imageList = sorted(glob.glob(parentDir+'/'+imageDir+'/*'))
        self.maskList = sorted(glob.glob(parentDir+'/'+maskDir+'/*'))
        # torchvision's Resize expects (height, width).
        self.size = tuple(size)

    def __getitem__(self, index):
        """Return ``(X, y)`` for sample ``index``.

        ``X`` is the normalised RGB image tensor. ``y`` is a two-channel bool
        tensor: channel 1 is True wherever the resized mask is non-zero and
        channel 0 is its complement.
        """
        resize = transforms.Resize(self.size, 2)  # 2 == PIL bilinear
        to_tensor = transforms.ToTensor()
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

        # Context managers close the underlying files as soon as the pixel
        # data has been converted.
        with Image.open(self.imageList[index]) as image_file:
            image = image_file.convert('RGB')
        X = normalize(to_tensor(resize(image)))

        with Image.open(self.maskList[index]) as mask_file:
            mask_image = mask_file.convert('L')
        foreground = to_tensor(resize(mask_image)).type(torch.BoolTensor)
        background = torch.bitwise_not(foreground)
        y = torch.cat([background, foreground], dim=0)

        return X, y

    def __len__(self):
        """Number of samples, i.e. the number of image files found."""
        return len(self.imageList)

In [None]:
# Train / validation / test splits; each pairs an image folder with its mask folder.
train_dataset = SegDataset(
    parentDir='foosball_dataset', imageDir='train_set', maskDir='train_set_mask')
validat_set = SegDataset(
    parentDir='foosball_dataset', imageDir='val_set', maskDir='val_set_mask')
test_set = SegDataset(
    parentDir='foosball_dataset', imageDir='test_set', maskDir='test_set_mask')

In [None]:
# Reshuffle the training samples every epoch so the batches are not always
# composed of the same files in the same order. Validation only aggregates
# metrics, so its order is left fixed.
trainLoader = DataLoader(train_dataset, batch_size=16, shuffle=True)
valLoader = DataLoader(validat_set, batch_size=16, shuffle=False)

In [None]:
def pixe_acc(target, predicted):
    """Fraction of pixels whose arg-max class agrees between two score maps.

    Args:
        target: 4-D tensor with the batch on axis 0 and the class scores
            (or one-hot labels) on axis 1.
        predicted: Tensor with exactly the same shape as ``target``.

    Returns:
        float: Mean over the batch of the per-sample pixel accuracy. Every
        sample in a batch has the same number of pixels, so this equals the
        accuracy over all pixels of the batch.

    Raises:
        ValueError: If the shapes differ or the tensors are not 4-D. Raising
            (rather than printing and returning ``None``) keeps a bad call
            from silently putting ``None`` into the metric history.
    """
    if target.shape != predicted.shape:
        raise ValueError(
            "target has dimension {}, predicted values have shape {}".format(
                tuple(target.shape), tuple(predicted.shape)))

    if target.dim() != 4:
        raise ValueError("target has dim {}, Must be 4.".format(target.dim()))

    def class_map(scores):
        # argmax may be unsupported for bool tensors depending on the torch
        # version, so promote those to uint8 first.
        if scores.dtype == torch.bool:
            scores = scores.to(torch.uint8)
        return scores.detach().argmax(dim=1)

    # Compare the class maps on whatever device the tensors live on; only the
    # final scalar is copied back to the host.
    agree = class_map(target) == class_map(predicted)
    return agree.double().mean().item()

In [None]:
# Network, loss, optimiser and learning-rate schedule for the training run.
model = FootballModel()
loss_fn = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Each lr_scheduler.step() multiplies the learning rate by gamma.
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.8)

In [None]:
def training_loop(n_epochs, optimizer, lr_scheduler, model, loss_fn, train_loader, val_loader, lastCkptPath = None):
    """Train ``model``, validating and writing a checkpoint after every epoch.

    Args:
        n_epochs: Number of epochs to run in this call.
        optimizer: Optimiser that updates ``model``'s parameters.
        lr_scheduler: Scheduler stepped once per epoch.
        model: Network to train; moved to CUDA when available, else CPU.
        loss_fn: Callable ``loss_fn(prediction, target)`` returning a scalar tensor.
        train_loader: Iterable of ``(X, y)`` training batches.
        val_loader: Iterable of ``(X, y)`` validation batches.
        lastCkptPath: Optional path of a checkpoint written by this function.
            When given, model weights, optimiser state, metric histories and
            the epoch counter are restored before training continues.

    Returns:
        tuple: ``(tr_loss_arr, val_loss_arr, pixelacctrain, pixelacctest)``,
        the per-batch training loss, validation loss, training pixel accuracy
        and validation pixel accuracy (including any restored history).

    Side effects:
        Overwrites ``checkpoint_fb.pt`` in the working directory after every epoch.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)

    tr_loss_arr = []
    val_loss_arr = []
    pixelacctrain = []
    pixelacctest = []
    prevEpoch = 0

    if lastCkptPath is not None:
        # map_location lets a checkpoint saved on one device type be resumed
        # on another (e.g. trained on GPU, resumed on CPU).
        checkpoint = torch.load(lastCkptPath, map_location=device)
        prevEpoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        # Make sure the optimiser's per-parameter tensors sit on the training device.
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device)

        # Histories are restored unconditionally, independent of whether the
        # optimiser holds any state tensors.
        tr_loss_arr = checkpoint['Training Loss']
        val_loss_arr = checkpoint['Validation Loss']
        pixelacctrain = checkpoint['PixelAcc train']
        pixelacctest = checkpoint['PixelAcc test']
        print("loaded model, ", checkpoint['description'], "at epoch", prevEpoch)

    for epoch in range(n_epochs):
        epoch_number = epoch + 1 + prevEpoch

        model.train()
        pbar = tqdm(train_loader, total = len(train_loader))
        for X, y in pbar:
            torch.cuda.empty_cache()
            X = X.to(device).float()
            y = y.to(device).float()
            ypred = model(X)
            loss = loss_fn(ypred, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            tr_loss_arr.append(loss.item())
            # float() keeps the history (and therefore the checkpoint) made of
            # plain Python numbers regardless of what pixe_acc returns.
            pixelacctrain.append(float(pixe_acc(y, ypred)))
            # The displayed means cover the whole recorded history, including
            # values restored from a checkpoint.
            pbar.set_postfix({'Epoch': epoch_number,
                              'Training Loss': np.mean(tr_loss_arr),
                              'Pixel Acc': np.mean(pixelacctrain)
                             })

        model.eval()
        with torch.no_grad():
            pbar = tqdm(val_loader, total = len(val_loader))
            for X, y in pbar:
                torch.cuda.empty_cache()
                X = X.to(device).float()
                y = y.to(device).float()
                ypred = model(X)

                val_loss_arr.append(loss_fn(ypred, y).item())
                pixelacctest.append(float(pixe_acc(y, ypred)))

                pbar.set_postfix({'Epoch': epoch_number,
                                  'Validation Loss': np.mean(val_loss_arr),
                                  'Pixel Acc': np.mean(pixelacctest)
                                 })

        # Step the schedule before saving so the optimiser state stored in the
        # checkpoint already carries the learning rate of the next epoch;
        # otherwise a resumed run repeats the previous epoch's rate once.
        lr_scheduler.step()

        checkpoint = {
            'epoch': epoch_number,
            'description': "add your description",
            'state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'Training Loss': tr_loss_arr,
            'Validation Loss': val_loss_arr,
            'PixelAcc train': pixelacctrain,
            'PixelAcc test': pixelacctest
        }
        torch.save(checkpoint, 'checkpoint_fb.pt')

    return tr_loss_arr, val_loss_arr, pixelacctrain, pixelacctest


In [None]:
# Train for a single epoch, starting fresh (no checkpoint path is passed).
retval = training_loop(
    n_epochs=1,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    model=model,
    loss_fn=loss_fn,
    train_loader=trainLoader,
    val_loader=valLoader,
)

In [None]:
def segment_hands(pathtest, model=None, checkpoint_path='checkpoint_fb.pt'):
    """Predict a cleaned-up binary foreground mask for one image.

    Args:
        pathtest: Either a NumPy image array (passed to ``Image.fromarray``)
            or anything ``Image.open`` accepts, e.g. a file path.
        model: Optional, already-loaded network to run. When omitted, a
            ``FootballModel`` is built and its weights are read from
            ``checkpoint_path`` on every call (the original behaviour).
            Callers segmenting many frames can pass a model to avoid
            re-reading the checkpoint each time.
        checkpoint_path: Checkpoint file used when ``model`` is not given;
            must contain a ``'state_dict'`` entry.

    Returns:
        numpy.ndarray: float32 mask holding 1.0 where output channel 1 scores
        at least as high as channel 0, then dilated, closed and opened with a
        5x5 elliptical kernel. Presumably 380x640 like the resized input
        (confirm against the network's output size).

    Note:
        The model used is moved to the selected device and left in eval mode.
    """
    # convert('RGB') guarantees the three channels Normalize expects, even
    # for grayscale or RGBA inputs.
    if isinstance(pathtest, np.ndarray):
        img = Image.fromarray(pathtest).convert('RGB')
    else:
        with Image.open(pathtest) as opened:
            img = opened.convert('RGB')

    preprocess = transforms.Compose([transforms.Resize((380, 640), 2),
                                     transforms.ToTensor(),
                                     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
    Xtest = preprocess(img)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    if model is None:
        # map_location lets a checkpoint saved on a GPU be loaded on a
        # CPU-only machine.
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model = FootballModel()
        model.load_state_dict(checkpoint['state_dict'])

    model.to(device)
    model.eval()
    with torch.no_grad():
        scores = model(Xtest.unsqueeze(0).to(device).float())[0]

    # Channel 1 is treated as the positive class, channel 0 as the negative one.
    ypos = scores[1].cpu().numpy()
    yneg = scores[0].cpu().numpy()
    mask = (ypos >= yneg).astype('float32')

    # Morphological clean-up: grow the region, then close gaps and remove specks.
    kernel = cv.getStructuringElement(cv.MORPH_ELLIPSE,(5,5))
    mask = cv.dilate(mask,kernel,iterations = 2)
    mask = cv.morphologyEx(mask, cv.MORPH_CLOSE, kernel)
    mask = cv.morphologyEx(mask, cv.MORPH_OPEN, kernel)
    return mask

In [None]:
def getcolored_mask(image, mask):
    """Return ``image`` with the masked area highlighted in channel 1.

    ``mask`` is cast to uint8 and scaled by 250, written into channel 1 of an
    otherwise all-zero overlay shaped like ``image``, and the overlay is then
    added to ``image`` with ``cv.addWeighted`` (both weights 1.0, no offset).
    """
    highlight = mask.astype('uint8') * 250
    overlay = np.zeros_like(image)
    overlay[:, :, 1] += highlight
    return cv.addWeighted(image, 1.0, overlay, 1.0, 0.0)

In [None]:
# Segment every 5th test image, preview it next to its overlay, and collect
# the side-by-side frames in img_array. Press 'q' in the preview to stop early.
img_array = []
# Sorted so frames are processed (and later written) in a reproducible order;
# glob alone returns paths in arbitrary order.
for i, filename in enumerate(sorted(glob.glob('foosball_dataset/test_set/*'))):
    if i % 5 == 0:
        img = cv.imread(filename)
        # cv.imread returns None instead of raising for unreadable or
        # non-image files; skip those rather than crashing in cv.resize.
        if img is not None:
            frame = cv.resize(img, (640, 380))  # (width, height)
            rgb = cv.cvtColor(frame, cv.COLOR_BGR2RGB)
            mask = segment_hands(rgb)
            colmask = getcolored_mask(frame, mask)
            fin_img = np.hstack((frame, colmask))
            img_array.append(fin_img)

            cv.imshow('color', fin_img)
    key = cv.waitKey(24)
    if key & 0xFF == ord('q'):
        break

# Close the preview window once the loop is done.
cv.destroyAllWindows()

In [None]:
# (width, height) of the frames collected in img_array: each one is two
# 640x380 images stacked side by side, i.e. 1280 pixels wide. The writer must
# be opened with this exact size, otherwise the frames are not written.
frame_size = (2 * 640, 380)

In [None]:
fourcc = cv.VideoWriter_fourcc(*'mp4v')

# cv.VideoWriter does not write frames whose size differs from the size it
# was opened with, so take the size from the collected frames themselves.
# frame_size is only the fallback when there is nothing to write.
if img_array:
    frame_height, frame_width = img_array[0].shape[:2]
    video_size = (frame_width, frame_height)
else:
    video_size = frame_size

video = cv.VideoWriter('video.avi', fourcc, 1, video_size)
try:
    for img in img_array:
        video.write(img)
finally:
    # Always finalise the file, even if writing a frame fails.
    video.release()