In [61]:
import cv2
import datetime
import os
import platform
import time
import torch.cuda 
import xml.etree.ElementTree as ET
import numpy as np
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
#from utils import *
#from model import *
import torch.nn as nn
import torch.nn.functional as F
import gc
import random
import sys
from shutil import rmtree
import matplotlib
import logging
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import mpl_toolkits.axisartist as AA
from mpl_toolkits.axes_grid1 import host_subplot
from sklearn.metrics import confusion_matrix
print('Python version : ', platform.python_version())
print('OpenCV version  : ', cv2.__version__)
print('Torch version : ', torch.__version__)
# Opencv use several cpus by default for simple operation. Using only one allows loading data in parallel much faster
cv2.setNumThreads(0)
print('Nb of threads for OpenCV : ', cv2.getNumThreads())
torch.cuda.is_available()

Python version :  3.8.8
OpenCV version  :  4.5.4
Torch version :  1.10.1+cu113
Nb of threads for OpenCV :  1


True

In [62]:
def reset_training(seed):
    gc.collect()
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

def print_and_log(message, log=None):
    print(message)
    if log is not None:
        log.info(message)

def setup_logger(logger_name, log_file, level=logging.INFO):
    l = logging.getLogger(logger_name)
    formatter = logging.Formatter('%(message)s')
    fileHandler = logging.FileHandler(log_file, mode='w')
    fileHandler.setFormatter(formatter)

    l.setLevel(level)
    l.addHandler(fileHandler)
    return l

def close_log(log):
    if log is not None:
        x = list(log.handlers)
        for i in x:
            log.removeHandler(i)
            i.flush()
            i.close()

def make_new_path(path):
    if os.path.exists(path):
        rmtree(path)
        os.mkdir(path)
    else:
        os.mkdir(path)

def make_path(path) :
    if not os.path.exists(path):
        os.mkdir(path)

class ActivePool(object):
    def __init__(self):
        super(ActivePool, self).__init__()
        self.active=[]
        self.running_time=[]
        self.lock=threading.Lock()
    def makeActive(self, name):
        with self.lock:
            self.active.append(name)
    def makeInactive(self, name):
        with self.lock:
            self.active.remove(name)
    def numActive(self):
        with self.lock:
            return len(self.active)
    def __str__(self):
        with self.lock:
            return str(self.active)

In [63]:
def progress_bar(count, total, title, completed=0, log=None):
    terminal_size = get_terminal_size()
    percentage = int(100.0 * count / total)
    length_bar = min([max([3, terminal_size[0] - len(title) - len(str(total)) - len(str(count)) - len(str(percentage)) - 10]),20])
    filled_len = int(length_bar * count / total)
    bar = '█' * filled_len + ' ' * (length_bar - filled_len)
    sys.stdout.write('%s [%s] %s %% (%d/%d)\r' % (title, bar, percentage, count, total))
    sys.stdout.flush()
    if completed:
        sys.stdout.write("\n")
        if log is not None:
            log.info('%s [%s] %s %% (%d/%d)' % (title, bar, percentage, count, total))


def _get_terminal_size_windows():
    try:
        from ctypes import windll, create_string_buffer
        h = windll.kernel32.GetStdHandle(-12)
        csbi = create_string_buffer(22)
        res = windll.kernel32.GetConsoleScreenBufferInfo(h, csbi)
        if res:
            (bufx, bufy, curx, cury, wattr,
             left, top, right, bottom,
             maxx, maxy) = struct.unpack("hhhhHhhhhhh", csbi.raw)
            sizex = right - left + 1
            sizey = bottom - top + 1
            return sizex, sizey
    except:
        pass
    
def get_terminal_size():
    current_os = platform.system()
    tuple_xy = None
    if current_os == 'Windows':
        tuple_xy = _get_terminal_size_windows()
        if tuple_xy is None:
            tuple_xy = _get_terminal_size_tput()
            # needed for window's python in cygwin's xterm!
    if current_os in ['Linux', 'Darwin'] or current_os.startswith('CYGWIN'):
        tuple_xy = _get_terminal_size_linux()
    if tuple_xy is None:
        tuple_xy = (80, 25)      # default value
    return tuple_xy

def _get_terminal_size_tput():
    try:
        cols = int(subprocess.check_call(shlex.split('tput cols')))
        rows = int(subprocess.check_call(shlex.split('tput lines')))
        return (cols, rows)
    except:
        pass

def _get_terminal_size_linux():
    def ioctl_GWINSZ(fd):
        try:
            import fcntl, termios, struct
            cr = struct.unpack('hh', fcntl.ioctl(fd, termios.TIOCGWINSZ, '1234'))
            return cr
        except:
            pass
    cr = ioctl_GWINSZ(0) or ioctl_GWINSZ(1) or ioctl_GWINSZ(2)
    if not cr:
        try:
            fd = os.open(os.ctermid(), os.O_RDONLY)
            cr = ioctl_GWINSZ(fd)
            os.close(fd)
        except:
            pass
    if not cr:
        try:
            cr = (os.environ['LINES'], os.environ['COLUMNS'])
        except:
            return None
    return int(cr[1]), int(cr[0])



In [64]:
def make_train_figure(loss_train, loss_val, acc_val, acc_train, path_to_save):

    host = host_subplot(111, axes_class=AA.Axes)
    par = host.twinx()

    host.set_xlabel("Epochs")
    host.set_ylabel("Loss")
    par.set_ylabel("Accuracy")

    par.axis["right"].toggle(all=True)

    epochs = [i for i in range(1, len(loss_val)+1)]

    host.set_xlim(1, len(epochs))
    host.set_ylim(0, np.max([np.max(loss_train), np.max(loss_val)]))
    par.set_ylim(0, 1)

    max_acc = max(acc_val)
    max_acc_idx = epochs[acc_val.index(max_acc)]
    host.set_title("Max Validation Accuracy: %.1f%% at iteration %d" % (max_acc*100, max_acc_idx))

    host.plot(epochs, loss_train, label="Train loss", linewidth=1.5)
    host.plot(epochs, loss_val, label="Validation loss", linewidth=1.5)
    par.plot(epochs, acc_val, label="Validation Accuracy", linewidth=1.5)
    par.plot(epochs, acc_train, label="Train Accuracy", linewidth=1.5)

    host.legend(loc='lower right', ncol=1, fancybox=False, shadow=True)

    plt.savefig(path_to_save)
    plt.close('all')
    return True
    

In [65]:
def plot_confusion_matrix(cm, classes, path, normalize=False, cmap=plt.cm.Blues):
    """
    This function prints and plots the confusion matrix.
    Normalization can be applied by setting 'normalize=True'.
    """
    
    acc = np.mean(np.array([cm[i,i] for i in range(len(cm))]).sum()/cm.sum()) * 100

    if normalize:
        cm_txt = cm.astype('float') / np.array([max(tmp,1) for tmp in cm.sum(axis=1)[:, np.newaxis]]).astype('float')
    else:
        cm_txt = cm
    
    cm = cm.astype('float') / np.array([max(tmp,1) for tmp in cm.sum(axis=1)[:, np.newaxis]]).astype('float')
    acc_2 = np.array([cm[i,i] for i in range(len(cm))])

    title = 'Accuracy of %.1f%% ($\\mu$ = %.1f with $\\sigma$ = %.1f)' % (acc, np.mean(acc_2)*100, np.std(acc_2)*100)
    plt.subplots(figsize=(12,12))

    plt.imshow(cm.astype('float'), interpolation='nearest', cmap=cmap, vmin=0, vmax=1)
    plt.title(title, fontsize=18)
    plt.colorbar(fraction=0.046, pad=0.04)
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=90, fontsize=14)
    plt.yticks(tick_marks, classes, fontsize=14)

    fmt = '.2g' if normalize else 'd'
    thresh = .5
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        if normalize:
            plt.text(j, i, format(round(cm_txt[i, j]*100,2), fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")
        else:
            plt.text(j, i, format(round(cm_txt[i, j],2), fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label', fontsize=14)
    plt.xlabel('Predicted label', fontsize=14)
    plt.tight_layout()
    plt.savefig(path)
    plt.close('all')

In [66]:
def frame_extractor(video_path, width, save_path):
    # Load Video
    cap = cv2.VideoCapture(video_path)
    length_video = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    frame_number = 0

    # Check if video uploaded
    if not cap.isOpened():
        sys.exit("Unable to open the video, check the path.\n")

    while frame_number < length_video:

        # Load video
        _, rgb = cap.read()

        # Check if load Properly
        if _ == 1:
            # Resizing and Save
            rgb = cv2.resize(rgb, (width, rgb.shape[0]*width//rgb.shape[1]))
            cv2.imwrite(os.path.join(save_path, '%08d.png' % frame_number), rgb)
            frame_number+=1
    cap.release()

In [67]:
def flatten_features(x):
    size = x.size()[1:]  # all dimensions except the batch dimension
    num_features = 1
    for s in size:
        num_features *= s
    return num_features

In [68]:
class NetSimpleBranch(nn.Module):
    def __init__(self, size_data, n_classes, channels=3):
        super(NetSimpleBranch, self).__init__()

        ####################
        ####### First ######
        ####################
        self.conv1 = nn.Conv3d(channels, 30, (3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        self.pool1 = nn.MaxPool3d((2, 2, 2), stride=(2, 2, 2))
        size_data //= 2

        ####################
        ###### Second ######
        ####################
        self.conv2 = nn.Conv3d(30, 60, (3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1)) 
        self.pool2 = nn.MaxPool3d((2, 2, 2), stride=(2, 2, 2))
        size_data //= 2

        ####################
        ####### Third ######
        ####################
        self.conv3 = nn.Conv3d(60, 80, (3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1)) 
        self.pool3 = nn.MaxPool3d((2, 2, 2), stride=(2, 2, 2))
        size_data //= 2

        ####################
        ####### Last #######
        ####################
        self.linear1 = nn.Linear(80*size_data[0]*size_data[1]*size_data[2], 500)
        self.relu = nn.ReLU()

        # Fusion
        self.linear2 = nn.Linear(500, n_classes)
        self.final = nn.Softmax(1)

    def forward(self, data):

        ####################
        ####### First ######
        ####################
        data = self.pool1(F.relu(self.conv1(data)))

        ####################
        ###### Second ######
        ####################
        data = self.pool2(F.relu(self.conv2(data)))

        ####################
        ####### Third ######
        ####################
        data = self.pool3(F.relu(self.conv3(data)))


        ####################
        ####### Last #######
        ####################
        data = data.view(-1, flatten_features(data))
        data = self.relu(self.linear1(data))

        data = self.linear2(data)
        label = self.final(data)

        return label
        print(model.summary)

In [69]:
class my_variables():
    def __init__(self, task_path, size_data=[98, 120, 120], cuda=True, batch_size=15, workers=6, epochs=1000, lr=0.0001, nesterov=True, weight_decay=0.005, momentum=0.5):
        self.cuda = cuda
        self.workers = workers
        self.batch_size = batch_size
        self.size_data = np.array(size_data)
        self.epochs = epochs
        self.lr = lr
        self.nesterov = nesterov
        self.weight_decay = weight_decay
        self.momentum = momentum
        self.model_name = os.path.join(task_path, 'MediaEval21_%s' % (datetime.datetime.now().strftime('%d-%m-%Y_%H-%M')))
        make_path(task_path)
        make_path(self.model_name)
        if cuda:
            self.dtype = torch.cuda.FloatTensor
            # os.environ[ 'CUDA_VISIBLE_DEVICES' ] = '0'
        else:
            self.dtype = torch.FloatTensor
        self.log = setup_logger('model_log', os.path.join(self.model_name, 'model_log.log'))

In [70]:
''' My_dataset class which uses My_stroke class to be used in the data loader'''
class My_dataset(Dataset):
    def __init__(self, dataset_list, size_data):
        self.dataset_list = dataset_list
        self.size_data = size_data

    def __len__(self):
        return len(self.dataset_list)

    def __getitem__(self, idx):
        rgb = get_data(self.dataset_list[idx].video_path, self.dataset_list[idx].begin, self.size_data)
        sample = {'rgb': torch.FloatTensor(rgb), 'label' : self.dataset_list[idx].move, 'my_stroke' : {'video_path':self.dataset_list[idx].video_path, 'begin':self.dataset_list[idx].begin, 'end':self.dataset_list[idx].end}}
        return sample

''' My_stroke class used for encoding the annotations'''
class My_stroke:
    def __init__(self, video_path, begin, end, move):
        self.video_path = video_path
        self.begin = begin
        self.end = end
        self.move = move

    def my_print(self, log=None):
        print_and_log('Video : %s\tbegin : %d\tEnd : %d\tClass : %s' % (self.video_path, self.begin, self.end, self.move), log=log)

''' Get annotations from xml files located in one folder and produce a list of My_stroke'''
def get_annotations(xml_path, data_folder, list_of_strokes=None):
    xml_list = [os.path.join(xml_path, f) for f in os.listdir(xml_path) if os.path.isfile(os.path.join(xml_path, f)) and f.split('.')[-1]=='xml']
    strokes_list = []
    for xml_file in xml_list:
        tree = ET.parse(xml_file)
        root = tree.getroot()
        video_path = os.path.join(data_folder, xml_file.split('/')[-1].split('.')[0])
        for action in root:
            if list_of_strokes is None:
                strokes_list.append(My_stroke(video_path, int(action.get('begin')), int(action.get('end')), 1))
            else:
                strokes_list.append(My_stroke(video_path, int(action.get('begin')), int(action.get('end')), list_of_strokes.index(action.get('move'))))
        # Case of the test set in segmentation task - build proposals of size 150
        if len(root)==0: 
            for begin in range(0,len(os.listdir(video_path))-150,150):
                strokes_list.append(My_stroke(video_path, begin, begin+150, 0))
    return strokes_list

'''Infer Negative Samples from annotation betwen strokes when there are more than length_min frames'''
def build_negative_strokes(stroke_list, length_min=200):
    video_path = 'tmp'
    for stroke in stroke_list.copy():
        if stroke.video_path != video_path:
            video_path = stroke.video_path
            begin_negative = 0
        end_negative = stroke.begin
        for end in range(begin_negative+length_min, end_negative, length_min):
            stroke_list.append(My_stroke(video_path, end-length_min, end, 0))
        begin_negative = stroke.end

''' Get the rgb frames from the annotations'''
def get_data(data_path, begin, size_data):
    rgb_data = []
    for frame_number in range(begin, begin + size_data[0]):
        try:
            rgb = cv2.imread(os.path.join(data_path, '%08d.png' % frame_number))
            rgb = cv2.resize(rgb, (size_data[1], size_data[2])).astype(float) / 255
        except:
            raise ValueError('Problem with %s begin %d size %d' % (os.path.join(data_path, '%08d.png' % frame_number), begin, size_data[0]))
        rgb_data.append(cv2.split(rgb))

    rgb_data = np.transpose(rgb_data, (1, 0, 2, 3))
    return rgb_data

In [71]:
def make_architecture(args, output_size):
    print_and_log('Make Model', log=args.log)
    model = NetSimpleBranch(args.size_data.copy(), output_size)
    ## Use GPU
    if args.cuda:
        model.cuda()
    return model

In [72]:
''' Training is split in train epoch and validation epoch and produce a plot'''
def train_model(model, args, train_loader, valid_loader):
    criterion = nn.CrossEntropyLoss(reduction='sum')
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=args.nesterov)

    begin_time = time.time()
    print_and_log('\nTraining...', log=args.log)

    # For plot
    loss_train = []
    loss_val = []
    acc_val = []
    acc_train = []

    for epoch in range(args.epochs):
        # Train and validation step and save loss and acc for plot
        loss_train_, acc_train_ = train_epoch(epoch, args, model, train_loader, optimizer, criterion)
        loss_val_, acc_val_ = validation_epoch(epoch, args, model, valid_loader, criterion)

        loss_train.append(loss_train_)
        acc_train.append(acc_train_)
        loss_val.append(loss_val_)
        acc_val.append(acc_val_)
    print_and_log('Max validation accuracy of %.2f done in %ds' % (max(acc_val), int(time.time()-begin_time)), log=args.log)
    make_train_figure(loss_train, loss_val, acc_val, acc_train, os.path.join(args.model_name, 'Train.png'))
    return 1

''' Update of the model in one epoch'''
def train_epoch(epoch, args, model, data_loader, optimizer, criterion):
    model.train()
    pid = os.getpid()
    N = len(data_loader.dataset)
    begin_time = time.time()
    aLoss = 0
    Acc = 0

    for batch_idx, batch in enumerate(data_loader):
        # Get batch tensor
        rgb, label = batch['rgb'], batch['label']

        rgb = Variable(rgb.type(args.dtype))
        label = Variable(label.type(args.dtype).long())

        optimizer.zero_grad()
        output = model(rgb)
        loss = criterion(output, label)
        loss.backward()
        optimizer.step()

        aLoss += loss.item()
        Acc += output.data.max(1)[1].eq(label.data).cpu().sum().numpy()
        progress_bar((batch_idx + 1) * args.batch_size, N, '%d Training - Epoch : %d - Batch Loss = %.5g' % (pid, epoch, loss.item()))

    aLoss /= N
    progress_bar(N, N, 'Train - Epoch %d - Loss = %.5g - Accuracy = %.3g (%d/%d) - Time = %ds' % (epoch, aLoss, Acc/N, Acc, N, time.time() - begin_time), 1, log=args.log)
    return aLoss, Acc/N


'''Validation of the model in one epoch'''
def validation_epoch(epoch, args, model, data_loader, criterion):
    with torch.no_grad():
        begin_time = time.time()
        pid = os.getpid()
        N = len(data_loader.dataset)
        _loss = 0
        _acc = 0

        for batch_idx, batch in enumerate(data_loader):
            progress_bar(batch_idx*args.batch_size, N, '%d - Validation' % (pid))
            rgb, label = batch['rgb'], batch['label']
            rgb = Variable(rgb.type(args.dtype))
            label = Variable(label.type(args.dtype).long())
            output = model(rgb)
            _loss += criterion(output, label).item()
            output_indexes = output.data.max(1)[1]
            _acc += output.data.max(1)[1].eq(label.data).cpu().sum().numpy()

        _loss /= N
        progress_bar(N, N, 'Validation - Loss = %.5g - Accuracy = %.3g (%d/%d) - Time = %ds' % (_loss, _acc/N, _acc, N, time.time() - begin_time), 1, log=args.log)
        return _loss, _acc/N

In [73]:
'''Store data for xml files from the list of stroke with predicted class - for detection it is saved when index predicted to 1'''
def store_xml_data(my_stroke_list, predicted, xml_files, list_of_strokes=None):
    for video_path, begin, end, prediction_index in zip(my_stroke_list['video_path'], my_stroke_list['begin'].tolist(), my_stroke_list['end'].tolist(), predicted):
        video_name = video_path.split('/')[-1]
        if video_name not in xml_files:
            xml_files[video_name] = ET.Element('video')
        if list_of_strokes is None:
            if prediction_index:
                stroke_xml = ET.SubElement(xml_files[video_name], 'action')
                stroke_xml.set('begin', str(begin))
                stroke_xml.set('end', str(end))
        else:
            stroke_xml = ET.SubElement(xml_files[video_name], 'action')
            stroke_xml.set('begin', str(begin))
            stroke_xml.set('end', str(end))
            stroke_xml.set('move', list_of_strokes[prediction_index])

'''Save the predictions in xml files'''
def save_xml_data(xml_files, path_xml_save):
    for video_name in xml_files:
        xml_file = open('%s.xml' % os.path.join(path_xml_save, video_name), 'wb')
        xml_file.write(ET.tostring(xml_files[video_name]))
        xml_file.close()

'''Inference on test set'''
def test_model(model, args, data_loader, list_of_strokes=None):
    with torch.no_grad():
        xml_files = {}
        path_xml_save = os.path.join(args.model_name, 'xml_test')
        make_path(path_xml_save)
        N = len(data_loader.dataset)
        
        for batch_idx, batch in enumerate(data_loader):
            # Get batch tensor
            rgb, my_stroke_list = batch['rgb'], batch['my_stroke']
            progress_bar(args.batch_size*batch_idx, N, 'Testing')

            rgb = Variable(rgb.type(args.dtype))
            output = model(rgb)
            _, predicted = torch.max(output.detach(), 1)
            store_xml_data(my_stroke_list, predicted, xml_files, list_of_strokes)

        progress_bar(N, N, 'Test done', 1, log=args.log)
        save_xml_data(xml_files, path_xml_save)


In [74]:
'''Set up the environment and extract data'''
def make_work_tree(main_folder, source_folder, frame_width=320, extract=False):
    data_path = os.path.join(main_folder, 'data')
    video_folder = os.path.join(source_folder, 'videos')
    detection_path = os.path.join(source_folder,'detectionTask')
    classification_path = os.path.join(source_folder,'classificationTask')
    if extract:
        make_path(main_folder)
        make_path(data_path)
        video_list = [_file for _file in os.listdir(video_folder) if _file[-4:]=='.mp4' and os.path.isfile(os.path.join(video_folder, _file))]
        for idx, video in enumerate(video_list):
            save_frame_path = os.path.join(data_path, video[:-4])
            make_path(save_frame_path)
            progress_bar(idx, len(video_list), 'Frame extraction')
            frame_extractor(os.path.join(video_folder, video), frame_width, save_frame_path)
        progress_bar(len(video_list), len(video_list), 'Frame extraction done', 1)
    return main_folder, data_path, detection_path, classification_path

''' According to overview paper'''
def get_list_of_strokes():
    list_of_strokes = ['Serve Forehand Backspin',
                   'Serve Forehand Loop',
                   'Serve Forehand Sidespin',
                   'Serve Forehand Topspin',

                   'Serve Backhand Backspin',
                   'Serve Backhand Loop',
                   'Serve Backhand Sidespin',
                   'Serve Backhand Topspin',

                   'Offensive Forehand Hit',
                   'Offensive Forehand Loop',
                   'Offensive Forehand Flip',

                   'Offensive Backhand Hit',
                   'Offensive Backhand Loop',
                   'Offensive Backhand Flip',

                   'Defensive Forehand Push',
                   'Defensive Forehand Block',
                   'Defensive Forehand Backspin',

                   'Defensive Backhand Push',
                   'Defensive Backhand Block',
                   'Defensive Backhand Backspin',
                   'Unknown']
    return list_of_strokes

'''Get the split of annotation and construct negative samples fro; it if in dectetion task'''
def get_lists_annotations(task_path, data_path, list_of_strokes=None):
    train_strokes = get_annotations(os.path.join(task_path, 'train'), data_path, list_of_strokes)
    valid_strokes = get_annotations(os.path.join(task_path, 'valid'), data_path, list_of_strokes)
    test_strokes = get_annotations(os.path.join(task_path, 'test'), data_path, list_of_strokes)
    if list_of_strokes is None:
        build_negative_strokes(train_strokes)
        build_negative_strokes(valid_strokes)
        build_negative_strokes(test_strokes)
    return train_strokes, valid_strokes, test_strokes

''' Build dataloader from list of strokes'''
def get_data_loaders(train_strokes, valid_strokes, test_strokes, size_data, batch_size, workers):
    # Build Dataset
    train_set = My_dataset(train_strokes, size_data)
    valid_set = My_dataset(valid_strokes, size_data)
    test_set = My_dataset(test_strokes, size_data)

    # Loaders of the Datasets
    train_loader = DataLoader(train_set, batch_size=batch_size, num_workers=0, shuffle=True)
    valid_loader = DataLoader(valid_set, batch_size=batch_size, num_workers=0, shuffle=True)
    test_loader = DataLoader(test_set, batch_size=batch_size, num_workers=0)
    return train_loader, valid_loader, test_loader

'''Classification task'''
def classification_task(main_folder, data_path, task_path):
    print('\nClassification Task')
    # Initial list
    reset_training(1)
    list_of_strokes = get_list_of_strokes()

    # Split
    train_strokes, valid_strokes, test_strokes = get_lists_annotations(task_path, data_path, list_of_strokes)
    
    # Model variables
    args = my_variables('classificationTask')
    
    ## Architecture with the output of the lenght of possible classes - (Unknown not counted)
    model = make_architecture(args, len(list_of_strokes)-1)

    # Loaders
    train_loader, valid_loader, test_loader = get_data_loaders(train_strokes, valid_strokes, test_strokes, args.size_data, args.batch_size, args.workers)

    # Training process
    train_model(model, args, train_loader, valid_loader)
    
    # Test process 
    test_model(model, args, test_loader, list_of_strokes)
    return 1

In [75]:
if __name__ == "__main__":
    # MediaEval Task source folder
    source_folder = 'C:/Users/s222237/OneDrive - University of Suffolk/data'
    
    # Prepare tree and data - To call only once with extract set to True
    main_folder, data_path, detection_path, classification_path = make_work_tree('.', source_folder, extract=False)

    # Tasks
    classification_task(main_folder, data_path, classification_path)

    print_and_log('All Done')


Classification Task
Make Model

Training...


ValueError: Problem with .\data\data\classificationTask\train\4657106722\00054120.png begin 54120 size 98