# Libraries


In [None]:
import tensorflow as tf
import pickle
import torch
from tqdm import tqdm
import time
from torch.autograd import Variable
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import json
from collections import OrderedDict
import numpy as np
import pandas as pd
import itertools
import Levenshtein as Lev
from sklearn import metrics
import torch.nn as nn
import os
from distutils.dir_util import mkpath
from google.colab import drive

In [None]:
# Colab only: mount Google Drive under /content/drive so the project folder is reachable.
drive.mount('/content/drive')

In [None]:
# IPython `cd` line magic (not plain Python): make the project folder on the mounted
# Drive the working directory, so relative paths such as 'log_dir/model_dev' resolve.
cd '/content/drive/My Drive/Advance Project - 523 4/Code/Eye Tracking Tool/gazeNet-Colab/'

In [None]:
# IPython shell escape: installs the package providing the `Levenshtein` module imported above.
!pip install python-levenshtein

In [None]:



def round_up_to_odd(f, min_val = 3):
    """Round ``f`` up to an odd integer, never going below ``min_val``.

    ``ceil(f)`` is taken first; an even result is bumped to the next odd
    number. When that value is smaller than ``min_val``, ``min_val`` is
    returned instead.
    """
    odd_value = np.int32(np.ceil(f) // 2 * 2 + 1)
    return max(odd_value, min_val)


def rolling_window(a, window):
    """Return a strided view of all length-``window`` windows along the last axis of ``a``.

    The result has shape ``a.shape[:-1] + (a.shape[-1] - window + 1, window)``
    and shares memory with ``a`` (no copy is made), so writing into it
    writes into ``a``.
    """
    n_windows = a.shape[-1] - window + 1
    view_shape = (*a.shape[:-1], n_windows, window)
    # Advancing one window reuses the stride of the last axis.
    view_strides = (*a.strides, a.strides[-1])
    return np.lib.stride_tricks.as_strided(a, shape=view_shape, strides=view_strides)

def BoxMuller_gaussian(u1,u2):
    """Box-Muller transform: map uniform samples ``u1``, ``u2`` to two standard normals.

    ``u1`` must be strictly positive (``log(u1)`` is taken). Returns the pair
    ``(z1, z2)`` with the same shape as the inputs.
    """
    radius = np.sqrt(-2*np.log(u1))
    angle = 2*np.pi*u2
    return radius*np.cos(angle), radius*np.sin(angle)

def convertToOneHot(vector, num_classes=None):
    """Convert a 1-D array of non-negative integer labels to a one-hot int matrix.

    Args:
        vector: non-empty 1-D ``np.ndarray`` of integer labels.
        num_classes: number of columns. Defaults to ``max(vector) + 1``; when
            given it must exceed every label in ``vector``.

    Returns:
        ``(len(vector), num_classes)`` int array with a single 1 per row, in
        the column given by that row's label.

    Raises:
        AssertionError: on a non-ndarray or empty input, or when
            ``num_classes`` is not positive or too small for the labels.
    """
    assert isinstance(vector, np.ndarray)
    assert len(vector) > 0

    if num_classes is None:
        num_classes = np.max(vector)+1
    else:
        assert num_classes > 0
        # Labels index columns 0..num_classes-1, so the largest label must be
        # strictly smaller than num_classes (the old `>=` let max == num_classes
        # through and then failed with an IndexError below).
        assert num_classes > np.max(vector)

    result = np.zeros(shape=(len(vector), num_classes))
    result[np.arange(len(vector)), vector] = 1
    return result.astype(int)

class Config(object):
    """Thin wrapper around a JSON parameter file.

    The parsed parameters live in ``self.params`` (an ``OrderedDict`` that
    keeps the key order of the file).
    """
    def __init__(self, param_file):
        self.param_file = param_file
        self.read_params()

    class bcolors(object):
        # ANSI escape sequences used to colour terminal output.
        HEADER = '\033[95m'
        OKBLUE = '\033[94m'
        OKGREEN = '\033[92m'
        WARNING = '\033[93m'
        FAIL = '\033[91m'
        ENDC = '\033[0m'
        BOLD = '\033[1m'
        UNDERLINE = '\033[4m'

    def read_params(self, current_params = None):
        """(Re)load ``self.params`` from ``self.param_file``.

        If ``current_params`` is given and differs from the freshly loaded
        parameters, print one warning line per changed key (old --> new) and
        return True. Otherwise return False.
        """
        with open(self.param_file, 'r') as f:
            self.params = json.load(f, object_pairs_hook=OrderedDict)
        if current_params is None or current_params == self.params:
            return False
        print ("TRAINING PARAMETERS CHANGED")
        missing = object()  # sentinel: key no longer present in the file
        # .items() instead of the Python-2-only .iteritems().
        for k, p in current_params.items():
            new = self.params.get(k, missing)
            if new is missing or not(p == new):
                shown = "<missing>" if new is missing else new
                print (self.bcolors.WARNING + \
                      "%s: %s --> %s" % (k, p, shown) + \
                      self.bcolors.ENDC)
        return True

    def save_params(self, params=None):
        """Write ``params`` (default: ``self.params``) to ``self.param_file`` as indented JSON.

        A non-empty ``params`` also replaces ``self.params``.
        """
        if not(params):
            params = self.params
        else:
            self.params = params

        with open(self.param_file, 'w') as f:
            json.dump(params, f, indent=4)

def human_format(num, suffixes=('', 'K', 'M', 'G', 'T', 'P')):
    """Format ``num`` with three decimals and a thousands suffix.

    ``num`` is divided by 1000 once for every power of 1000 its magnitude
    reaches (capped at the last suffix), e.g. ``1500 -> '1.500K'``.

    Args:
        num: number to format (sign is preserved).
        suffixes: suffix for each power of 1000, starting at 1000**0. A tuple
            by default so the shared default object cannot be mutated.
    """
    magnitude = sum(abs(num/1000.0**x) >= 1 for x in range(1, len(suffixes)))
    return '%.3f%s' % (num/1000.**magnitude, suffixes[magnitude])

In [None]:
class EventParser(object):
    """Turns one gaze sample sequence into the network's input features.

    ``config`` must provide ``'augment'`` (truthy to add noise) and, when
    augmenting, ``'augment_noise'`` (arguments for ``np.arange`` giving the
    candidate RMS noise levels).
    """
    def __init__(self, config):
        super(EventParser, self).__init__()
        self.config = config


    def parse_data(self, sample):
        """Return sample-to-sample displacements of ``sample['x']``/``sample['y']``.

        Args:
            sample: array with ``'x'`` and ``'y'`` fields (e.g. an
                ``ETData.dtype`` structured array).

        Returns:
            float32 array of ``(dx, dy)`` rows, one per consecutive pair of
            samples (``len(sample) - 1`` rows).

        When ``config['augment']`` is set, Gaussian noise scaled by a randomly
        chosen level/2 is added to x and y before differencing (uses NumPy's
        global RNG).
        """
        config = self.config

        gaze_x = np.copy(sample['x'])
        gaze_y = np.copy(sample['y'])

        if config['augment']:
            rms_noise_levels = np.arange(*config["augment_noise"])
            u1, u2 = np.random.uniform(0, 1, (2, len(sample)))
            # np.random.uniform samples [0, 1) but Box-Muller takes log(u1);
            # flip u1 onto (0, 1] so log(0) = -inf noise can never occur.
            u1 = 1.0 - u1
            noise_x, noise_y = BoxMuller_gaussian(u1, u2)
            rms_noise_level = np.random.choice(rms_noise_levels)
            noise_x *= rms_noise_level/2
            noise_y *= rms_noise_level/2
            gaze_x += noise_x
            gaze_y += noise_y

        inpt_x, inpt_y = np.diff(gaze_x), np.diff(gaze_y)

        X = np.array(list(zip(inpt_x, inpt_y)), dtype=np.float32)

        return X

In [None]:
class EMDataset(Dataset, EventParser):
    """Dataset of gaze windows cut from the all-valid stretches of each recording.

    Args:
        config: reads ``'split_seqs'`` and ``'seq_len'`` here, plus whatever
            ``EventParser.parse_data`` reads (``'augment'``, ``'augment_noise'``).
        gaze_data: iterable of structured arrays with at least ``'status'``,
            ``'x'``, ``'y'`` and ``'evt'`` fields.

    Each item is ``(FloatTensor of shape (2, len(window) - 1), evt labels, ())``.
    """
    def __init__(self, config, gaze_data):

        split_seqs = config['split_seqs']

        # The network input is diff(gaze), one sample shorter than the raw
        # window, so every window needs seq_len + 1 raw samples.
        seq_len = config['seq_len']+1
        seq_step = seq_len  # non-overlapping windows

        data = []
        for d in gaze_data:  # one entry per recording
            # Cut wherever 'status' flips, keep chunks that are entirely valid
            # and long enough for one window. np.intp replaces the np.int0
            # alias, which NumPy 2.0 removed.
            breaks = np.where(np.diff(d['status'].astype(np.intp)) != 0)[0]+1
            chunks = [_d for _d in np.split(d, breaks)
                      if _d['status'].all() and len(_d) >= seq_len]

            for seq in chunks:
                if split_seqs:
                    # Fixed-length windows every seq_step samples; a window that
                    # would run past the end is replaced by the last full window.
                    seqs = [seq[pos:pos + seq_len] if (pos + seq_len) < len(seq) else
                            seq[len(seq)-seq_len:len(seq)] for pos in range(0, len(seq), seq_step)]
                else:
                    seqs = [seq]

                data.extend(seqs)

        self.data = data
        self.size = len(data)
        self.config = config

        super(EMDataset, self).__init__(config)

    def __getitem__(self, index):
        sample = self.data[index]
        gaze_data = self.parse_data(sample)
        evt = self.parse_evt(sample['evt'])

        return torch.FloatTensor(gaze_data.T), evt, ()

    def parse_evt(self, evt):
        """Align labels with diff(gaze) (drop the first) and shift them down by one.

        NOTE(review): if ``evt`` is unsigned (ETData uses uint8), a label 0
        wraps around to 255 here.
        """
        return evt[1:]-1

    def __len__(self):
        return self.size

In [None]:
def _collate_fn(batch):
    def func(p):
        return p[0].size(1)

    #return batch
    longest_sample = max(batch, key=func)[0]
    freq_size = longest_sample.size(0)
    minibatch_size = len(batch)
    max_seqlength = longest_sample.size(1)
    inputs = torch.zeros(minibatch_size, 1, freq_size, max_seqlength)
    input_percentages = torch.FloatTensor(minibatch_size)
    target_sizes = [] 
    targets = []

    for x in range(minibatch_size):
        sample = batch[x]

        tensor, target, (_) = sample
        seq_length = tensor.size(1)
        inputs[x][0].narrow(1, 0, seq_length).copy_(tensor)
        input_percentages[x] = seq_length / float(max_seqlength)
        target_sizes.append(len(target))
        targets.extend(target.tolist())
    targets = torch.LongTensor(targets)
    return inputs, targets, input_percentages, target_sizes, (_)

In [None]:
class GazeDataLoader(DataLoader):
    """DataLoader that batches with ``_collate_fn`` and seeds NumPy's global RNG.

    Accepts every ``torch.utils.data.DataLoader`` argument plus an optional
    ``seed`` keyword (default 220617) passed to ``np.random.seed``. A
    ``collate_fn`` keyword, when given, overrides ``_collate_fn``.
    """
    def __init__(self, *args, **kwargs):
        # Previously `seed` was read from a notebook global that is only defined
        # in a later cell; take it from kwargs instead (same default value).
        seed = kwargs.pop('seed', 220617)
        kwargs.setdefault('collate_fn', _collate_fn)
        super(GazeDataLoader, self).__init__(*args, **kwargs)
        # Global NumPy seed: affects every np.random user, e.g. EventParser's
        # noise augmentation in the main process.
        np.random.seed(seed)


In [None]:
def aggr_events(events_raw):
    """Run-length encode a label sequence into ``[start, end, label]`` entries.

    ``start`` is inclusive and ``end`` exclusive (sample indices), e.g.
    ``[1, 1, 2] -> [[0, 2, 1], [2, 3, 2]]``.
    """
    aggregated = []
    start = 0
    for label, run in itertools.groupby(events_raw):
        stop = start + sum(1 for _ in run)
        aggregated.append([start, stop, label])
        start = stop
    return aggregated

In [None]:
class ETData():
    """One eye-tracking recording stored as a NumPy structured array.

    ``self.data`` uses ``ETData.dtype``: ``t`` (timestamps, presumably in
    seconds since ``fs`` is estimated as 1 / sample interval), ``x``/``y``
    (gaze position), ``status`` (sample-valid flag) and ``evt`` (integer
    event label, see ``evt_color_map``). ``self.fs`` is the sampling rate
    estimated by ``load``; ``self.evt`` is the per-event table built by
    ``calc_evt``.
    """
    #Data types and constants
    # np.bool_ rather than the np.bool alias, which NumPy 1.24 removed.
    dtype = np.dtype([
        ('t', np.float64),
        ('x', np.float32),
        ('y', np.float32),
        ('status', np.bool_),
        ('evt', np.uint8)
    ])
    evt_color_map = dict({
        0: 'gray',  #0. Undefined
        1: 'b',     #1. Fixation
        2: 'r',     #2. Saccade
        3: 'y',     #3. Post-saccadic oscillation
        4: 'm',     #4. Smooth pursuit
        5: 'k',     #5. Blink
        9: 'k',     #9. Other
    })

    def __init__(self):
        self.data = np.array([], dtype=ETData.dtype)
        self.fs = None
        self.evt = None

    def load(self, fpath, **kwargs):
        """Load samples into ``self.data`` and estimate ``self.fs``.

        Without a ``source`` keyword, ``fpath`` is a path for ``np.load``; a
        failure is printed and ``self.data`` keeps its previous value.
        Otherwise ``source`` selects the input kind:

        * ``'etdata'``   -- ``fpath`` is a path for ``np.load``;
        * ``'array'``    -- ``fpath`` is already an ``ETData.dtype`` array
          (returns False on a dtype mismatch);
        * ``'np_array'`` -- ``fpath`` is a 2-D array with one row per sample
          and one column per field, converted field by field;
        * a callable     -- called as ``source(fpath, ETData.dtype)``.

        Returns ``self.data`` (or False, see above); on success ``self.evt``
        is reset to None.
        """
        if not('source' in kwargs):
            try:
                self.data = np.load(fpath)
            except Exception:
                # Best effort: report and continue with the previous data.
                print("ERROR loading %s" % fpath)
        else:
            if kwargs['source']=='etdata':
                self.data = np.load(fpath)

            if kwargs['source']=='array':
                if not fpath.dtype == ETData.dtype:
                    print ("Error. Data types do not match")
                    return False
                self.data = fpath

            if kwargs['source']=='np_array':
                # np.rec is the public spelling; np.core is private in NumPy 2.
                self.data = np.rec.fromarrays(fpath.T,
                                              dtype=ETData.dtype)

            if callable(kwargs['source']):
                self.data = kwargs['source'](fpath, ETData.dtype)

        #estimate sampling rate
        self.fs = float(self.find_nearest_fs(self.data['t']))
        self.evt = None
        return self.data

    def save(self, spath):
        """Write ``self.data`` with ``np.save``."""
        np.save(spath, self.data)

    def find_nearest_fs(self, t):
        """Snap the median rate ``1 / diff(t)`` to the closest value in a fixed list of rates."""
        fs = np.array([2000, 1250, 1000, 600, 500,  #high end
                       300, 250, 240, 200,          #middle end
                       120, 75, 60, 50, 30, 25])    #low end
        t = np.median(1/np.diff(t))
        print("ETDATA-----------------------------",t)  # debug output
        return fs.flat[np.abs(fs - t).argmin()]

    def calc_evt(self, fast=False):
        '''Aggregate per-sample labels into an event table (``self.evt``).

        Columns always computed: ``s``/``e`` (start/end sample index, end
        exclusive), ``evt`` (label), ``dur_s`` (samples) and ``dur``
        (``dur_s / self.fs``). With ``fast=False`` per-event position/velocity
        measures and amplitudes are added via ``calc_event_data``.
        NOTE(review): ``calc_event_data`` is not defined in this notebook as
        shown; ``fast=False`` fails unless it is provided elsewhere.
        '''
        evt_compact = aggr_events(self.data['evt'])
        evt = pd.DataFrame(evt_compact,
                           columns = ['s', 'e', 'evt'])
        evt['dur_s'] = np.diff(evt[['s', 'e']], axis=1).squeeze()
        evt['dur'] = evt['dur_s']/self.fs

        if not(fast):
            evt['posx_s'], evt['posx_e'], evt['posy_s'], evt['posy_e'],\
            evt['posx_mean'], evt['posy_mean'], evt['posx_med'], evt['posy_med'],\
            evt['pv'], evt['pv_index'], evt['rms'], evt['std']   = \
               zip(*map(lambda x: calc_event_data(self, x), evt_compact))
            # Squeeze the (n, 1) difference to a column, like dur_s above.
            evt['ampl_x'] = np.diff(evt[['posx_s', 'posx_e']], axis=1).squeeze()
            evt['ampl_y'] = np.diff(evt[['posy_s', 'posy_e']], axis=1).squeeze()
            evt['ampl'] = np.hypot(evt['ampl_x'], evt['ampl_y'])
        #TODO:
        #   calculate fix-to-fix saccade amplitude
        self.evt = evt
        return self.evt

    def plot(self, spath = None, save=False, show=True, title=None):
        '''Plot x(t), y(t) and the x/y trace, coloured by event label.

        NOTE(review): ``plt`` (matplotlib.pyplot) is not imported anywhere in
        this notebook as shown; import it before calling. ``spath``, ``save``
        and ``title`` are currently unused.
        '''
        if show:
            plt.ion()
        else:
            plt.ioff()

        fig = plt.figure(figsize=(10,6))
        ax00 = plt.subplot2grid((2, 2), (0, 0))
        ax10 = plt.subplot2grid((2, 2), (1, 0), sharex=ax00)
        ax01 = plt.subplot2grid((2, 2), (0, 1), rowspan=2)

        ax00.plot(self.data['t'], self.data['x'], '-')
        ax10.plot(self.data['t'], self.data['y'], '-')
        ax01.plot(self.data['x'], self.data['y'], '-')
        # .items() instead of the Python-2-only .iteritems().
        for e, c in ETData.evt_color_map.items():
            mask = self.data['evt'] == e
            ax00.plot(self.data['t'][mask], self.data['x'][mask], '.', color = c)
            ax10.plot(self.data['t'][mask], self.data['y'][mask], '.', color = c)
            ax01.plot(self.data['x'][mask], self.data['y'][mask], '.', color = c)

        etdata_extent = np.nanmax([np.abs(self.data['x']), np.abs(self.data['y'])])+1

        ax00.axis([self.data['t'].min(), self.data['t'].max(), -etdata_extent, etdata_extent])
        ax10.axis([self.data['t'].min(), self.data['t'].max(), -etdata_extent, etdata_extent])
        ax01.axis([-etdata_extent, etdata_extent, -etdata_extent, etdata_extent])

In [None]:
def calc_k(gt, pr):
    """Cohen's kappa between ``gt`` and ``pr``, defined as 1.0 when they agree everywhere.

    The short-circuit avoids calling sklearn on perfectly agreeing inputs.
    """
    if (gt == pr).all():
        return 1.
    return metrics.cohen_kappa_score(gt, pr)

In [None]:
def _binarize_events(etdata, evt):
    """Return a deep copy of ``etdata`` relabelled for a one-vs-rest comparison of ``evt``.

    Samples labelled ``evt`` become 1, every other sample becomes 0, and
    samples that were 0 (undefined) in the input become 255. The event table
    of the copy is rebuilt with ``calc_evt(fast=True)``.
    """
    import copy  # the notebook only imports `copy` in a later cell

    _etdata = copy.deepcopy(etdata)
    labels = _etdata.data['evt']  # field view: writes go into _etdata.data
    mask_ext = labels == 0
    mask = labels == evt
    labels[mask] = 1
    labels[~mask] = 0
    labels[mask_ext] = 255
    _etdata.calc_evt(fast=True)
    return _etdata


def eval_evt(etdata_gt, etdata_pr, n_events):
    """Score predicted event labels against ground truth.

    Args:
        etdata_gt, etdata_pr: ETData-like objects. ``.data['evt']`` holds
            per-sample labels (0 = undefined); ``.evt`` is built with
            ``calc_evt(fast=True)`` when still None.
        n_events: number of label values (including 0) used for one-hot encoding.

    Returns:
        ``(wer, cer, ke_, ks, f1e_, (evt_overlap, evt_gt, evt_pr))``:
        ``wer`` is the Levenshtein distance between the event-label sequences
        divided by the number of ground-truth events; ``cer`` is the per-sample
        Levenshtein distance divided by the number of defined ground-truth
        samples; ``ks`` holds sample-level Cohen's kappa per class
        ``1..n_events-1`` plus overall; ``ke_``/``f1e_`` hold event-level kappa
        and F1 for events 1-3 plus overall.

        If the event-level stage raises, ``ks``, ``ke_`` and ``f1e_`` become
        lists of ``n_events + 1`` zeros and ``evt_overlap`` may be None.
    """
    if etdata_gt.evt is None:
        etdata_gt.calc_evt(fast=True)
    if etdata_pr.evt is None:
        etdata_pr.calc_evt(fast=True)

    # Event-level Levenshtein distance over the sequence of event labels,
    # ignoring undefined (0) events.
    evt_gt = etdata_gt.evt['evt']
    evt_gt = evt_gt[~(evt_gt == 0)]
    evt_pr = etdata_pr.evt['evt']
    evt_pr = evt_pr[~(evt_pr == 0)]
    wer = Lev.distance(''.join(map(str, evt_gt)),
                       ''.join(map(str, evt_pr))) / float(len(evt_gt))

    # Sample-level Levenshtein distance: the per-sample label strings are cut
    # at undefined ('0') samples and the pieces compared pairwise.
    pieces_gt = ''.join(map(str, etdata_gt.data['evt'])).split('0')
    pieces_pr = ''.join(map(str, etdata_pr.data['evt'])).split('0')
    mask = etdata_gt.data['evt'] == 0
    evt_len = float(sum(~mask))
    cer = sum(Lev.distance(_a, _b) for _a, _b in zip(pieces_gt, pieces_pr)) / evt_len

    # Sample-level kappa: per class (one-vs-rest), then over all defined samples.
    evts_gt_oh = convertToOneHot(etdata_gt.data['evt'], n_events)
    evts_pr_oh = convertToOneHot(etdata_pr.data['evt'], n_events)
    ks = [calc_k(evts_gt_oh[:, i], evts_pr_oh[:, i]) for i in range(1, n_events)]

    evt_gt = etdata_gt.data['evt']
    evt_gt = evt_gt[~(evt_gt == 0)]
    evt_pr = etdata_pr.data['evt']
    evt_pr = evt_pr[~(evt_pr == 0)]
    ks.append(metrics.cohen_kappa_score(evt_gt, evt_pr))

    # Event-level kappa and F1. Initialised so the fallback below can still
    # return (previously an early failure raised UnboundLocalError here).
    evt_overlap = None
    try:
        ke_ = []
        f1e_ = []
        for evt in range(1, 4):
            _etdata_gt = _binarize_events(etdata_gt, evt)
            _etdata_pr = _binarize_events(etdata_pr, evt)

            evt_overlap, evt_gt, evt_pr = calc_KE(_etdata_gt, _etdata_pr)

            # Drop matched pairs where both sides are undefined.
            mask = (evt_gt == 255) & (evt_pr == 255)
            evt_gt = evt_gt[~mask]
            evt_pr = evt_pr[~mask]
            ke_.append(calc_k(evt_gt, evt_pr))
            f1e_.append(calc_f1(evt_gt, evt_pr))

        evt_overlap, evt_gt, evt_pr = calc_KE(etdata_gt, etdata_pr)
        mask = (evt_gt == 0) & (evt_pr == 0)
        evt_gt = evt_gt[~mask]
        evt_pr = evt_pr[~mask]
        # Per-class one-vs-rest scores; computed but not part of the return value.
        evt_gt_oh = convertToOneHot(evt_gt, n_events)
        evt_pr_oh = convertToOneHot(evt_pr, n_events)
        ke = [calc_k(evt_gt_oh[:, i], evt_pr_oh[:, i]) for i in range(1, n_events)]
        f1e = [calc_f1(evt_gt_oh[:, i], evt_pr_oh[:, i]) for i in range(1, n_events)]

        ke_all = metrics.cohen_kappa_score(evt_gt, evt_pr)
        f1_all = metrics.f1_score(evt_gt, evt_pr, average='weighted')
        ke_.append(ke_all)
        f1e_.append(f1_all)
    except Exception:
        # TODO: debug. Best-effort fallback (e.g. when calc_KE / calc_f1 raise
        # or are not defined): report zeros for the returned score lists
        # instead of failing the whole evaluation.
        ks = [0.,] * (n_events + 1)
        ke_ = [0.,] * (n_events + 1)
        f1e_ = [0.,] * (n_events + 1)

    return wer, cer, ke_, ks, f1e_, (evt_overlap, evt_gt, evt_pr)

In [None]:
def run_infer(model, n_samples, data_loader, **kwargs):
    """Run ``model`` over ``data_loader`` and collect per-sample predictions.

    Keyword args:
        cuda (default False): run the forward pass on the GPU.
        use_tqdm (default False): wrap the loader in a tqdm progress bar.
        eval (default True): score predictions with ``eval_evt``.
        fs (default 500.): sampling rate used to synthesise timestamps for the
            reconstructed sequences.

    Iteration stops after the first batch that brings the number of processed
    sequences to at least ``n_samples``. NOTE: this function does not call
    ``model.eval()``; set the mode before calling if dropout/batch norm matter.

    Returns:
        With ``eval``: ``(wer, cer, ke, ks, (gt_rows, pr_rows, raw_outputs))``;
        otherwise ``(gt_rows, pr_rows, raw_outputs)``. The row lists hold
        ``(t, x, y, status, evt)`` tuples (x/y rebuilt by cumulative sum of the
        input displacements, ``evt`` = class index + 1), with one all-zero row
        after every sequence as a separator.
    """
    fs = kwargs.get("fs", 500.)
    cuda = kwargs.get("cuda", False)
    use_tqdm = kwargs.get("use_tqdm", False)
    perform_eval = kwargs.get("eval", True)

    etdata_pr = ETData()
    etdata_gt = ETData()
    _etdata_pr = []
    _etdata_gt = []
    _pr_raw = []

    sample_accum = 0
    t = time.time()
    iterator = tqdm(data_loader) if use_tqdm else data_loader
    with torch.no_grad():
        for data in iterator:
            inputs, targets, input_percentages, target_sizes, aux = data

            #do forward pass
            if cuda:
                inputs = inputs.cuda()
            y = model(inputs)

            if cuda:
                inputs = inputs.cpu()
                y = y.cpu()
                targets = targets.cpu()

            #decode output: first target_size steps of each sequence, arg-max class
            outputs_split = [_y[:_l] for _y, _l in zip(y.data, target_sizes)]
            events_decoded = [torch.max(_o, 1)[1].numpy().flatten() for _o in outputs_split]
            # `targets` is every sample's labels concatenated in batch order, one
            # len(target) chunk each (recorded in target_sizes), so split on those
            # lengths. Lengths re-derived from input_percentages * output length
            # could be truncated by int() and misalign gt with predictions.
            events_target = np.array_split(targets.numpy(), np.cumsum(target_sizes)[:-1])

            # Rebuild gaze positions from the input displacements.
            trials = [np.cumsum(_y[0, :, :_l], axis=1).T for _y, _l in zip(inputs.data.numpy(), target_sizes)]

            for gt, pr, pr_raw, tr in zip(events_target, events_decoded, outputs_split, trials):
                # Guard against any remaining length mismatch (e.g. a model
                # output shorter than the target).
                minl = min(len(gt), len(pr))
                gt = gt[:minl]
                pr = pr[:minl]
                _pr_raw.append(pr_raw.numpy())
                timestamps = np.arange(len(gt))/fs
                _etdata_pr.extend(zip(timestamps,
                              tr[:,0],
                              tr[:,1],
                              itertools.repeat(True),
                              pr+1
                           ))
                _etdata_pr.append((0, )*5)
                _etdata_gt.extend(zip(timestamps,
                              tr[:,0],
                              tr[:,1],
                              itertools.repeat(True),
                              gt+1
                           ))
                _etdata_gt.append((0, )*5)

                sample_accum+=1

            if sample_accum >= n_samples:
                break
        print ('[FP], n_samples: %d, dur: %.2f' % (sample_accum, time.time()-t))

    if perform_eval:
        #run evaluation
        etdata_pr.load(np.array(_etdata_pr), **{'source':'np_array'})
        etdata_gt.load(np.array(_etdata_gt), **{'source':'np_array'})
        wer, cer, ke, ks, _, (evt_overlap, _, _) = eval_evt(etdata_gt, etdata_pr, 4)
        return wer, cer, ke, ks, (_etdata_gt, _etdata_pr, _pr_raw)
    else:
        return _etdata_gt, _etdata_pr, _pr_raw

In [None]:
def checkpoint(model, step=None, epoch=None):
    """Bundle ``model.state_dict()`` with its training position for ``torch.save``.

    ``epoch``/``step`` default to the string ``'N/A'`` when not given. A value
    of 0 is kept as 0 (the old truthiness test turned step 0 into ``'N/A'``).
    """
    package = {
        'epoch': epoch if epoch is not None else 'N/A',
        'step': step if step is not None else 'N/A',
        'state_dict': model.state_dict(),
    }
    return package

In [None]:
def load(model, model_dir, config, model_name=None):
    """Restore weights for ``model`` from ``<model_dir>/models``.

    Uses ``model_name`` if given, else the last entry of ``config["model_name"]``.
    Saved keys have their first ``'module.'`` (added by ``DataParallel``)
    stripped; only keys that ``model`` has are copied, others are ignored.

    Returns:
        ``(model_name, epoch)`` where ``epoch`` is the saved epoch + 1, or 1
        when nothing was loaded or the saved epoch is ``'N/A'``.
    """
    if len(config["model_name"]) or (model_name is not None):
        model_name = config["model_name"][-1] if model_name is None else model_name
    else:
        model_name = None
    logdir = model_dir+"/models"

    fpath_model = "%s/%s" % (logdir, model_name)
    if os.path.exists(fpath_model) and (model_name is not None):
        print ("Loading model: %s" % fpath_model)

        # Load onto the CPU regardless of where the tensors were saved.
        package = torch.load(fpath_model, map_location=lambda storage, loc: storage)
        epoch = package['epoch']+1 if not(package['epoch'] == 'N/A') else 1
        # Build a renamed copy: renaming keys in place while iterating the
        # dict raised RuntimeError ("dictionary keys changed during iteration").
        saved = {k.replace('module.', '', 1): v
                 for k, v in package['state_dict'].items()}

        model_state = model.state_dict()
        model_state.update({k: v for k, v in saved.items() if k in model_state})
        model.load_state_dict(model_state)
        print ("done.")
    else:
        epoch = 1
        print ("Pretrained model not found")
    return model_name, epoch

In [None]:
def save(model, model_dir, epoch, step,config):
    """Write a checkpoint and keep only the newest ``config['max_to_keep']`` on disk.

    The file ``gazeNET_<epoch:04d>_<step:08d>.pth.tar`` is written to
    ``<model_dir>/models`` and appended to ``config["model_name"]``. Older
    names beyond the newest ``max_to_keep`` are deleted from disk and dropped
    from the list (a ``max_to_keep`` of 0 keeps everything, because of how
    the negative slices below behave). ``config`` is modified in place and
    also returned.
    """
    logdir = model_dir+"/models"
    # os.makedirs instead of distutils' mkpath: distutils was removed in Python 3.12.
    os.makedirs(logdir, exist_ok=True)
    fname_model = 'gazeNET_%04d_%08d.pth.tar' %(epoch, step)
    file_path = '%s/%s' % (logdir, fname_model)

    torch.save(checkpoint(model, step, epoch), file_path)
    config["model_name"].append(fname_model)
    keep = config['max_to_keep']
    model_list = config["model_name"][-keep:]
    remove_list = config["model_name"][:-keep]
    for _rm in remove_list:
        fpath_rm = '%s/%s' % (logdir, _rm)
        if os.path.exists(fpath_rm):
            os.remove(fpath_rm)
    config["model_name"] = model_list
    return config

In [None]:
def calc_params(model):
    """Map every ``state_dict`` entry of ``model`` to its element count, in state_dict order.

    Buffers (e.g. batch-norm running statistics) are counted as well, since
    they appear in ``state_dict``.
    """
    return OrderedDict(
        (entry_name, tensor.nelement())
        for entry_name, tensor in model.state_dict().items()
    )

In [None]:
class SequenceWise(nn.Module):
    """Apply a per-row module at every position of the first two dimensions of a tensor.

    The first two dimensions are merged into one before calling ``module`` and
    split back afterwards, so modules such as ``Linear`` or ``BatchNorm1d``
    see a 2-D ``(dim0 * dim1, features)`` input.
    """
    def __init__(self, module):
        super(SequenceWise, self).__init__()
        self.module = module

    def forward(self, x):
        lead0, lead1 = x.size(0), x.size(1)
        flat_out = self.module(x.view(lead0 * lead1, -1))
        return flat_out.view(lead0, lead1, -1)

    def __repr__(self):
        return '%s (\n%s)' % (self.__class__.__name__, self.module.__repr__())

In [None]:
class BatchRNN(nn.Module):
    """One recurrent layer: GRU -> optional batch norm -> dropout.

    The GRU is ``batch_first``, so inputs are ``(batch, time, features)``. For
    a bidirectional GRU the two direction halves of the output are summed,
    so the layer always emits ``hidden_size`` features per step.
    """
    def __init__(self, input_size, hidden_size, bidirectional=False, batch_norm=True, keep_prob=0.5):
        super(BatchRNN, self).__init__()
        self.batch_norm = batch_norm
        self.bidirectional = bidirectional

        # GRU bias is turned off when batch norm follows (presumably because
        # the norm's own shift makes it redundant).
        self.rnn = nn.GRU(input_size=input_size,
                          hidden_size=hidden_size,
                          bidirectional=bidirectional,
                          batch_first=True,
                          bias=not batch_norm)
        # Created even when batch_norm is False, so it is always in state_dict.
        self.batch_norm_op = SequenceWise(nn.BatchNorm1d(hidden_size))

        self.dropout_op = nn.Dropout(1 - keep_prob)

    def forward(self, x):
        out, _hidden = self.rnn(x)
        out = out.contiguous()
        if self.bidirectional:
            n_batch, n_steps = out.size(0), out.size(1)
            # (N, T, 2*H) -> (N, T, 2, H) -> sum over directions -> (N, T, H)
            out = out.view(n_batch, n_steps, 2, -1).sum(2).view(n_batch, n_steps, -1)
            out = out.contiguous()
        if self.batch_norm:
            out = self.batch_norm_op(out)
        return self.dropout_op(out)

In [None]:
class gazeNET(nn.Module):
    """gazeNET per-sample event classifier: optional 2-D conv stack, GRU stack, linear head.

    Args:
        config: reads ``config['architecture']['conv_stack']`` (optional list
            of ``[name, out_channels, kernel_size, stride]``),
            ``config['architecture']['rnn_stack']`` (list of
            ``[name, hidden_size, batch_norm, bidirectional]``) and
            ``config['keep_prob']``.
        num_classes: number of output classes per time step.
        seed: seeds torch (and CUDA when devices exist) before any layer is
            created, so weight initialisation is reproducible.

    ``forward`` takes ``(batch, 1, 2, time)`` inputs (as built by
    ``_collate_fn``) and returns ``(batch, time', num_classes)`` scores.
    """
    def __init__(self, config, num_classes, seed=220617):
        super(gazeNET, self).__init__()
        torch.manual_seed(seed)
        if (torch.cuda.device_count()>0):
            torch.cuda.manual_seed(seed)

        if 'conv_stack' in config['architecture']:
            ## convolutional stack
            conv_config = config['architecture']['conv_stack']
            conv_stack = []
            feat_dim = 2  # input height: the x and y displacement rows
            in_channels = 1
            for _conv in conv_config:
                name, out_channels, kernel_size, stride = _conv
                padding = tuple(int(k/2) for k in kernel_size)
                _conv = nn.Conv2d(in_channels, out_channels,
                              kernel_size=tuple(kernel_size), stride=tuple(stride),
                              padding = padding,
                              bias = False
                              )
                _conv = nn.Sequential(
                    _conv,
                    nn.BatchNorm2d(out_channels),
                    nn.Hardtanh(0, 20, inplace=True),
                    nn.Dropout(1-config['keep_prob']),
                )
                conv_stack.append((name, _conv))
                in_channels = out_channels
                # Output height from Conv2d's size formula (dilation 1). The old
                # `feat_dim/stride+1` used true division on Python 3 and ignored
                # kernel and padding; any config where it disagreed with the real
                # height would crash the GRU with an input-size mismatch.
                feat_dim = (feat_dim + 2*padding[0] - kernel_size[0]) // stride[0] + 1
            self.conv_stack = nn.Sequential(OrderedDict(conv_stack))
            rnn_input_size = feat_dim * out_channels
        else:
            self.conv_stack = None
            rnn_input_size = 2

        ## RNN stack
        rnn_config = config['architecture']['rnn_stack']
        rnn_stack = []
        for _rnn in rnn_config:
            name, hidden_size, batch_norm, bidirectional = _rnn
            rnn_input_size=int(rnn_input_size)
            _rnn = BatchRNN(input_size=rnn_input_size, hidden_size=hidden_size,
                            bidirectional=bidirectional, batch_norm=batch_norm,
                            keep_prob = config['keep_prob'])
            rnn_stack.append((name, _rnn))
            rnn_input_size = hidden_size
        self.rnn_stack = nn.Sequential(OrderedDict(rnn_stack))

        ## FC stack (rnn_input_size now holds the last layer's output width)
        self.fc = nn.Sequential(
            SequenceWise(nn.Linear(int(rnn_input_size), num_classes, bias=False)),
        )
    ### forward
    def forward(self, x):
        if self.conv_stack is not None:
            x = self.conv_stack(x)

        sizes = x.size()
        x = x.view(sizes[0], sizes[1] * sizes[2], sizes[3])  # (N, C, H, T) -> (N, C*H, T)
        x = x.transpose(1, 2).contiguous()  # (N, T, C*H): the GRUs are batch_first

        x = self.rnn_stack(x)

        x = self.fc(x)
        return x

# Train

### Arguments and Directory

In [None]:
model_dir='log_dir/model_dev'  # run folder: config.json, data/, TB/ and models/ live under it
num_workers=1  # worker processes for every DataLoader below
num_epochs=20  # last epoch to train (epochs are numbered from 1)
seed=220617  # passed to gazeNET, which seeds torch with it


dir_data=model_dir+'/data'  # pickled training / validation data


### Config Loader

In [None]:
# Parse <model_dir>/config.json; `config` is the parameter dict (an OrderedDict)
# the rest of the notebook reads its settings from.
configuration = Config(model_dir+'/config.json')
config=configuration.params

### Cuda and Batch Size

In [None]:
# Use CUDA only when a device exists AND the config asks for it. Logical `and`
# (rather than bitwise `&`) also works for non-bool config values. With CUDA,
# the configured batch size is multiplied by the number of devices.
if torch.cuda.device_count() > 0 and config['cuda']:
    batch_size = config["batch_size"] * torch.cuda.device_count()
    cuda = True
else:
    # CPU fallback: fixed batch size; config["batch_size"] is not used here.
    batch_size = 100
    cuda = False

### Load Training Data

In [None]:
log_writer_train = tf.summary.create_file_writer(model_dir + '/TB/train')
train_file = dir_data + '/data.gen.pkl'
# NOTE(review): pickle.load can execute arbitrary code -- only open trusted files.
# encoding="bytes" is the setting for unpickling data written by Python 2.
with open(train_file, 'rb') as f:
    X_train = pickle.load(f, encoding="bytes")

### Load Validation Data

In [None]:
log_writer_val = tf.summary.create_file_writer(model_dir + '/TB/val')
val_file = dir_data + '/data.val_clean.pkl'
# NOTE(review): pickle.load can execute arbitrary code -- only open trusted files.
with open(val_file, 'rb') as f:
    X_val = pickle.load(f, encoding="bytes")
    # Each entry is a pair; only the second element (the trial data) is kept.
    X_val = [trial for _key, trial in X_val]

### Load Generative Training Data

In [None]:
log_writer_train_gen = tf.summary.create_file_writer(model_dir + '/TB/train_gen')
train_gen_file = dir_data + '/data.unpaired_clean.pkl'
# NOTE(review): pickle.load can execute arbitrary code -- only open trusted files.
with open(train_gen_file, 'rb') as f:
    X_train_gen = pickle.load(f, encoding="bytes")
    # Each entry is a pair; only the second element (the trial data) is kept.
    X_train_gen = [trial for _key, trial in X_train_gen]

### Train Data


In [None]:
# Training windows cut from X_train with the main config (split_seqs / augment as set in config.json).
dataset_train = EMDataset(config = config, gaze_data = X_train)

In [None]:
# Shuffled training batches; batch_size comes from the CUDA cell above.
loader_train = GazeDataLoader(dataset_train, batch_size=batch_size,
                              num_workers=num_workers,
                              shuffle=True)

### Validation Data

In [None]:
import copy

# Evaluation variant of the config: whole (unsplit) sequences, one per batch,
# and no noise augmentation. A deep copy leaves the training config untouched.
config_val = copy.deepcopy(config)
config_val.update(split_seqs=False, batch_size=1, augment=False)

In [None]:
# Validation: whole sequences (config_val), one per batch, in fixed order.
dataset_val = EMDataset(config = config_val, gaze_data = X_val)
loader_val = GazeDataLoader(dataset_val, batch_size=1,
                            num_workers=num_workers,
                            shuffle=False)

### Generative Training Data

In [None]:
# The generative (unpaired) trials are served like the validation set:
# config_val (unsplit, no augmentation), batch size 1, no shuffling.
dataset_train_gen = EMDataset(config = config_val, gaze_data = X_train_gen)
loader_train_gen = GazeDataLoader(dataset_train_gen, batch_size=1,
                                 num_workers=num_workers,
                                 shuffle=False)

### Prepare Model

In [None]:
# One output class per entry of config['events'].
num_classes = len(config['events'])
model = gazeNET(config, num_classes, seed)
# Element count of every state_dict entry, summed for a readable total.
n_params = calc_params(model)
print("Number of parameters: %s" % human_format(sum(n_params.values())))

In [None]:
# Resume from the newest checkpoint listed in config["model_name"], if it exists;
# epoch_start is the saved epoch + 1, otherwise 1.
_, epoch_start = load(model, model_dir, config)

# Weights are restored into the bare model before it is wrapped in DataParallel.
if cuda:
    model = torch.nn.DataParallel(model).cuda()

parameters = model.parameters()
optimizer = torch.optim.RMSprop(parameters, lr=config["learning_rate"])

In [None]:
# Loss class weights from the label frequencies of the generative training
# trials long enough to be used (>= seq_len + 1 samples).
event_stats = np.array([_e for _x in X_train_gen
                           for _e in _x['evt'].tolist()
                           if len(_x) >= config['seq_len']+1])
# Count labels 1, 2 and 3 (the first three classes) directly. One-hot encoding
# with len(np.unique(...)) columns crashed, or put labels in the wrong column,
# whenever the label values had gaps.
event_weights = np.array([(event_stats == label).sum() for label in (1, 2, 3)])
event_weights = event_weights.astype(np.float32)/event_weights.sum()
# Rarer classes get larger weights.
weights = torch.FloatTensor(1-event_weights)

if cuda:
    weights_cuda = weights.cuda()
    criterion = torch.nn.CrossEntropyLoss(weights_cuda)
else:
    criterion = torch.nn.CrossEntropyLoss(weights)

### Training Model

In [None]:
model.train()
val_score_best = 0
for epoch in range(epoch_start, num_epochs+1): #because we start from 1
    end = time.time()
    loss = 0
    # Defaults so the per-epoch summary also works for an empty loader.
    t_data = 0.
    t_model_s = end
    for step, data in enumerate(loader_train):
        global_step = len(loader_train)*(epoch-1) + step

        ## Prepare data
        inputs, targets, input_percentages, target_sizes, _ = data
        t_data = time.time() - end

        t_model_s = time.time()
        y_ = targets
        if cuda:
            inputs = inputs.cuda()
            y_ = y_.cuda()

        ## Forward pass
        y = model(inputs)
        yt, yn = y.size()[:2]
        y = y.view(yt * yn, -1)
        # WARNING: only works for split_seqs=True;
        # i.e. all sequences need to be the same exact length
        loss = criterion(y, y_)

        ## Backward pass -- only update on a finite loss so a bad batch cannot
        ## corrupt the weights (the old check was inverted and stepped the
        ## optimizer only when the loss WAS NaN).
        if torch.isfinite(loss):
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        end = time.time()

        ## Model persistence
        if config['save_every'] != 0 and global_step % config['save_every'] == 0:
            config = save(model, model_dir, epoch, global_step, config)
    print('Epoch: %d, Loss: %.3f, t_data = %.3f, t_model:%.3f' % (epoch, float(loss), t_data, end-t_model_s))