In [None]:
class AONet:

    def __init__(self, hps: HParameters):
        self.hps = hps
        self.model = None
        self.log_file = None
        self.verbose = hps.verbose


    def initialize(self, cuda_device=None):
        rnd_seed = 12345
        random.seed(rnd_seed)
        np.random.seed(rnd_seed)
        torch.manual_seed(rnd_seed)

        self.model = VASNet()
        self.model.eval()
        self.model.apply(weights_init)##?
        #print(self.model)

        cuda_device = cuda_device or self.hps.cuda_device

        if self.hps.use_cuda:
            print("Setting CUDA device: ",cuda_device)
            torch.cuda.set_device(cuda_device)
            torch.cuda.manual_seed(rnd_seed)

        if self.hps.use_cuda:
            self.model.cuda()

        return
    
    
    def load_datasets(self, datasets = None):
        """
        Loads all h5 datasets from the datasets list into a dictionary self.dataset
        referenced by their base filename
        :param datasets:  List of dataset filenames
        :return:
        """
        if datasets is None:
            datasets = self.hps.datasets

        datasets_dict = {}
        for dataset in datasets:
            _, base_filename = os.path.split(dataset)
            base_filename, _ = os.path.splitext(base_filename)
            print("Loading:", dataset)
            # dataset_name = base_filename.split('_')[2]
            # print("\tDataset name:", dataset_name)
            # datasets_dict[base_filename] = h5py.File(dataset, 'r')
            datasets_dict[base_filename] = dataset

        self.datasets_dict = datasets_dict
        
        return datasets_dict
    
    
    
    def load_model(self, model_filename):
        self.model.load_state_dict(torch.load(model_filename, map_location=lambda storage, loc: storage))
        return

    
    
    
#     def fix_keys(self, keys, dataset_name = None):
#         """
#         :param keys:
#         :return:
#         """
#         # dataset_name = None
#         if len(self.datasets) == 1:
#             dataset_name = next(iter(self.datasets))

#         keys_out = []
#         for key in keys:
#             t = key.split('/')
#             if len(t) != 2:
#                 assert dataset_name is not None, "ERROR dataset name in some keys is missing but there are multiple dataset {} to choose from".format(len(self.datasets))

#                 key_name = dataset_name+'/'+key
#                 keys_out.append(key_name)
#             else:
#                 keys_out.append(key)

#         return keys_out


    


    def load_split_file(self, splits_file):

        self.dataset_name, self.dataset_type, self.splits = parse_splits_filename(splits_file)
        n_folds = len(self.splits)
        self.split_file = splits_file
        print("Loading splits from: ",splits_file)

        return n_folds


#     def select_split(self, split_id):
#         print("Selecting split: ",split_id)

#         self.split_id = split_id
#         n_folds = len(self.splits)
#         assert self.split_id < n_folds, "split_id (got {}) exceeds {}".format(self.split_id, n_folds)

#         split = self.splits[self.split_id]
#         self.train_keys = split['train_keys']
#         self.test_keys = split['test_keys']

#         dataset_filename = self.hps.get_dataset_by_name(self.dataset_name)[0]
#         _,dataset_filename = os.path.split(dataset_filename)
#         dataset_filename,_ = os.path.splitext(dataset_filename)
#         self.train_keys = self.fix_keys(self.train_keys, dataset_filename)
#         self.test_keys = self.fix_keys(self.test_keys, dataset_filename)
#         return



#     def get_data(self, key):
#         key_parts = key.split('/')
#         assert len(key_parts) == 2, "ERROR. Wrong key name: "+key
#         dataset, key = key_parts
#         return self.datasets[dataset][key]

#     def lookup_weights_file(self, data_path):
#         dataset_type_str = '' if self.dataset_type == '' else self.dataset_type + '_'
#         weights_filename = data_path + '/models/{}_{}splits_{}_*.tar.pth'.format(self.dataset_name, dataset_type_str, self.split_id)
#         weights_filename = glob.glob(weights_filename)
#         if len(weights_filename) == 0:
#             print("Couldn't find model weights: ", weights_filename)
#             return ''

#         # Get the first weights filename in the dir
#         weights_filename = weights_filename[0]
#         splits_file = data_path + '/splits/{}_{}splits.json'.format(self.dataset_name, dataset_type_str)

#         return weights_filename, splits_file


    def train(self, output_dir='EX-0'):

        print("Initializing VASNet model and optimizer...")
        self.model.train()

        criterion = nn.MSELoss()

        if self.hps.use_cuda:
            criterion = criterion.cuda()

        parameters = filter(lambda p: p.requires_grad, self.model.parameters())
        
        self.optimizer = torch.optim.Adam(parameters, lr=self.hps.lr[0], weight_decay=self.hps.l2_req)

        print("Starting training...")

        max_val_fscore = 0
        max_val_fscore_epoch = 0
        train_keys = self.train_keys[:]

        lr = self.hps.lr[0]
        for epoch in range(self.hps.epochs_max):

            print("Epoch: {0:6}".format(str(epoch)+"/"+str(self.hps.epochs_max)), end='')
            self.model.train()
            avg_loss = []

            random.shuffle(train_keys)

            for i, key in enumerate(train_keys):
                dataset = self.get_data(key)
                seq = dataset['features'][...]
                seq = torch.from_numpy(seq).unsqueeze(0)
                target = dataset['gt_score'][...]
                target = torch.from_numpy(target).unsqueeze(0)

                # Normalize frame scores
                target -= target.min()
                target /= target.max()

                if self.hps.use_cuda:
                    seq, target = seq.float().cuda(), target.float().cuda()

                seq_len = seq.shape[1]
                y, _ = self.model(seq,seq_len)
                loss_att = 0

                loss = criterion(y, target)
                # loss2 = y.sum()/seq_len
                loss = loss + loss_att
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                avg_loss.append([float(loss), float(loss_att)])

            # Evaluate test dataset
            val_fscore, video_scores = self.eval(self.test_keys)
            if max_val_fscore < val_fscore:
                max_val_fscore = val_fscore
                max_val_fscore_epoch = epoch

            avg_loss = np.array(avg_loss)
            print("   Train loss: {0:.05f}".format(np.mean(avg_loss[:, 0])), end='')
            print('   Test F-score avg/max: {0:0.5}/{1:0.5}'.format(val_fscore, max_val_fscore))

            if self.verbose:
                video_scores = [["No", "Video", "F-score"]] + video_scores
                print_table(video_scores, cell_width=[3,40,8])

            # Save model weights
            path, filename = os.path.split(self.split_file)
            base_filename, _ = os.path.splitext(filename)
            path = os.path.join(output_dir, 'models_temp', base_filename+'_'+str(self.split_id))
            os.makedirs(path, exist_ok=True)
            filename = str(epoch)+'_'+str(round(val_fscore*100,3))+'.pth.tar'
            torch.save(self.model.state_dict(), os.path.join(path, filename))

        return max_val_fscore, max_val_fscore_epoch


#     def eval(self, dataset, keys=None, results_filename=None):

#         self.model.eval()
#         summary = {}
#         att_vecs = {}
        
#         with torch.no_grad(), h5py.File(self.datasets_dict[dataset], 'a') as d:
#             if keys==None:
#                 keys=d.keys()
                
#             for i, key in enumerate(keys):
#                 # data = self.get_data(key)
#                 # seq = self.dataset[key]['features'][...]
#                 seq = d[key]['features'][...]
#                 seq = torch.from_numpy(seq).unsqueeze(0)

#                 if self.hps.use_cuda:
#                     seq = seq.float().cuda()

#                 y, att_vec = self.model(seq, seq.shape[1])
#                 summary[key] = y[0].detach().cpu().numpy()
#                 att_vecs[key] = att_vec.detach().cpu().numpy()

#             f_score, video_scores = self.eval_summary(summary, keys,  d, results_filename=results_filename, metric=self.hps.dataset_name, att_vecs=att_vecs)

#         return f_score, video_scores


#     def eval_summary(self, machine_summary_activations, test_keys, dataset,  results_filename=None, metric='tvsum', att_vecs=None):

#         eval_metric = 'avg' if metric == 'tvsum' else 'max'

#         if results_filename is None:
#             results_filename = 'results/test_result001.h5'
#         fms = []
#         video_scores = []

#         with h5py.File(results_filename, 'w') as h5_res:
        
#             for key_idx, key in enumerate(test_keys):
#                 d = dataset[key]
#                 probs = machine_summary_activations[key]

#                 if 'change_points' not in d:
#                     print("ERROR: No change points in dataset/video ",key)

#                 cps = d['change_points'][...]
#                 num_frames = d['n_frames'][()]
#                 nfps = d['n_frame_per_seg'][...].tolist()
#                 positions = d['picks'][...]
#                 # user_summary = d['user_summary'][...]

#                 machine_summary = generate_summary(probs, cps, num_frames, nfps, positions)
#                 # fm, _, _ = evaluate_summary(machine_summary, user_summary, eval_metric)
#                 # fms.append(fm)

#                 # Reporting & logging
#                 video_scores.append([key_idx + 1, key, "{:.1%}".format(fm)])

#                 if results_filename:
#                     gt = d['gtscore'][...]
#                     h5_res.create_dataset(key + '/score', data=probs)
#                     h5_res.create_dataset(key + '/machine_summary', data=machine_summary)
#                     h5_res.create_dataset(key + '/gtscore', data=gt)
#                     # h5_res.create_dataset(key + '/fm', data=fm)
#                     h5_res.create_dataset(key + '/picks', data=positions)

#                     video_name = key.split('/')[1]
#                     if 'video_name' in d:
#                         video_name = d['video_name'][...]
#                     h5_res.create_dataset(key + '/video_name', data=video_name)

#                     if att_vecs is not None:
#                         h5_res.create_dataset(key + '/att', data=att_vecs[key])

#         mean_fm = np.mean(fms)

#         # Reporting & logging
#         # if results_filename is not None:
#         #     h5_res.close()

#         return mean_fm, video_scores



In [None]:

def train(hps):
    # os.makedirs(hps.output_dir, exist_ok=True)
    # os.makedirs(os.path.join(hps.output_dir, 'splits'), exist_ok=True)
    # os.makedirs(os.path.join(hps.output_dir, 'code'), exist_ok=True)
    # os.makedirs(os.path.join(hps.output_dir, 'models'), exist_ok=True)
    # os.system('cp -f splits/*.json  ' + hps.output_dir + '/splits/')
    # os.system('cp *.py ' + hps.output_dir + '/code/')

    # Create a file to collect results from all splits
    f = open(hps.results_path, 'wt')
           
        
    # for split_filename in hps.split_file:
    #     dataset_name, dataset_type, splits = parse_splits_filename(split_filename)

        # For no augmentation use only a dataset corresponding to the split file
#         datasets = None
#         if dataset_type == '':
#             datasets = hps.get_dataset_by_name(dataset_name)

#         if datasets is None:
#             datasets = hps.datasets

    f_avg = 0
    f = open(hps.split_file)
    splits = json.load(f)
    n_folds = len(splits)
    for split_id in range(n_folds):
        ao = AONet(hps)
        ao.initialize()
        ao.load_datasets(datasets=datasets)
        ao.load_split_file(splits_file=split_filename)
        ao.select_split(split_id=split_id)

        fscore, fscore_epoch = ao.train(output_dir=hps.output_dir)
        f_avg += fscore

        # Log F-score for this split_id
        f.write(split_filename + ', ' + str(split_id) + ', ' + str(fscore) + ', ' + str(fscore_epoch) + '\n')
        f.flush()

        # Save model with the highest F score
        _, log_file = os.path.split(split_filename)
        log_dir, _ = os.path.splitext(log_file)
        log_dir += '_' + str(split_id)
        log_file = os.path.join(hps.output_dir, 'models', log_dir) + '_' + str(fscore) + '.tar.pth'

        os.makedirs(os.path.join(hps.output_dir, 'models', ), exist_ok=True)
        os.system('mv ' + hps.output_dir + '/models_temp/' + log_dir + '/' + str(fscore_epoch) + '_*.pth.tar ' + log_file)
        os.system('rm -rf ' + hps.output_dir + '/models_temp/' + log_dir)

        print("Split: {0:}   Best F-score: {1:0.5f}   Model: {2:}".format(split_filename, fscore, log_file))

    # Write average F-score for all splits to the results.txt file
    f_avg /= n_folds
    f.write(split_filename + ', ' + str('avg') + ', ' + str(f_avg) + '\n')
    f.flush()

    f.close()


## Vasnet

**Packages**

In [18]:
import h5py
import math
import random
import numpy as np
import os.path as osp
import json

import import_ipynb
from Model import VASNet

from ortools.algorithms import pywrapknapsack_solver

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.autograd import Variable

### Util modules

In [4]:
def mkdir_if_missing(directory):
    if not osp.exists(directory):
        try:
            os.makedirs(directory)
        except OSError as e:
            if e.errno != errno.EEXIST:
                raise

def write_json(obj, fpath):
    mkdir_if_missing(osp.dirname(fpath))
    with open(fpath, 'w') as f:
        json.dump(obj, f, indent=4, separators=(',', ': '))


In [5]:
def split_random(keys, num_videos, num_train):
    """Random split"""
    train_keys, test_keys = [], []
    rnd_idxs = np.random.choice(range(num_videos), size=num_train, replace=False)
    for key_idx, key in enumerate(keys):
        if key_idx in rnd_idxs:
            train_keys.append(key)
        else:
            test_keys.append(key)

    assert len(set(train_keys) & set(test_keys)) == 0, "Error: train_keys and test_keys overlap"

    return train_keys, test_keys

In [6]:
def parse_splits_filename(splits_filename):
    # Parse split file and count number of k_folds
    spath, sfname = os.path.split(splits_filename)
    sfname, _ = os.path.splitext(sfname)
    dataset_name = sfname.split('_')[0]  # Get dataset name e.g. tvsum
    dataset_type = sfname.split('_')[1]  # augmentation type e.g. aug

    # The keyword 'splits' is used as the filename fields terminator from historical reasons.
    if dataset_type == 'splits':
        # Split type is not present
        dataset_type = ''

    # Get number of discrete splits within each split json file
    with open(splits_filename, 'r') as sf:
        splits = json.load(sf)

    return dataset_name, dataset_type, splits


In [7]:
def weights_init(m):
    classname = m.__class__.__name__
    if classname == 'Linear':
        init.xavier_uniform_(m.weight, gain=np.sqrt(2.0))
        if m.bias is not None:
            init.constant_(m.bias, 0.1)

In [8]:
def knapsack_ortools(values, weights, items, capacity ):
    scale = 1000
    values = np.array(values)
    weights = np.array(weights)
    values = (values * scale).astype(np.int32)
    weights = (weights).astype(np.int32)
    capacity = capacity
    osolver = pywrapknapsack_solver.KnapsackSolver(pywrapknapsack_solver.KnapsackSolver.KNAPSACK_DYNAMIC_PROGRAMMING_SOLVER,'test')
    osolver.Init(values.tolist(), [weights.tolist()], [capacity])
    computed_value = osolver.Solve()
    packed_items = [x for x in range(0, len(weights))
                    if osolver.BestSolutionContains(x)]

    return packed_items

In [9]:
def generate_summary(ypred, cps, n_frames, nfps, positions, proportion=0.15, method='knapsack'):
    """Generate keyshot-based video summary i.e. a binary vector.
    Args:
    ---------------------------------------------
    - ypred: predicted importance scores.
    - cps: change points, 2D matrix, each row contains a segment.
    - n_frames: original number of frames.
    - nfps: number of frames per segment.
    - positions: positions of subsampled frames in the original video.
    - proportion: length of video summary (compared to original video length).
    - method: defines how shots are selected, ['knapsack', 'rank'].
    """
    n_segs = cps.shape[0]
    frame_scores = np.zeros((n_frames), dtype=np.float32)
    if positions.dtype != int:
        positions = positions.astype(np.int32)
    if positions[-1] != n_frames:
        positions = np.concatenate([positions, [n_frames]])
    for i in range(len(positions) - 1):
        pos_left, pos_right = positions[i], positions[i+1]
        if i == len(ypred):
            frame_scores[pos_left:pos_right] = 0
        else:
            frame_scores[pos_left:pos_right] = ypred[i]

    seg_score = []
    for seg_idx in range(n_segs):
        start, end = int(cps[seg_idx,0]), int(cps[seg_idx,1]+1)
        scores = frame_scores[start:end]
        seg_score.append(float(scores.mean()))

    limits = int(math.floor(n_frames * proportion))

    if method == 'knapsack':
        #picks = knapsack_dp(seg_score, nfps, n_segs, limits)
        picks = knapsack_ortools(seg_score, nfps, n_segs, limits)
    elif method == 'rank':
        order = np.argsort(seg_score)[::-1].tolist()
        picks = []
        total_len = 0
        for i in order:
            if total_len + nfps[i] < limits:
                picks.append(i)
                total_len += nfps[i]
    else:
        raise KeyError("Unknown method {}".format(method))

    summary = np.zeros((1), dtype=np.float32) # this element should be deleted
    for seg_idx in range(n_segs):
        nf = nfps[seg_idx]
        if seg_idx in picks:
            tmp = np.ones((nf), dtype=np.float32)
        else:
            tmp = np.zeros((nf), dtype=np.float32)
        summary = np.concatenate((summary, tmp))

    summary = np.delete(summary, 0) # delete the first element
    return summary


def evaluate_summary(machine_summary, user_summary, eval_metric='avg'):
    """Compare machine summary with user summary (keyshot-based).
    Args:
    --------------------------------
    machine_summary and user_summary should be binary vectors of ndarray type.
    eval_metric = {'avg', 'max'}
    'avg' averages results of comparing multiple human summaries.
    'max' takes the maximum (best) out of multiple comparisons.
    """
    machine_summary = machine_summary.astype(np.float32)
    user_summary = user_summary.astype(np.float32)
    n_users,n_frames = user_summary.shape

    # binarization
    machine_summary[machine_summary > 0] = 1
    user_summary[user_summary > 0] = 1

    if len(machine_summary) > n_frames:
        machine_summary = machine_summary[:n_frames]
    elif len(machine_summary) < n_frames:
        zero_padding = np.zeros((n_frames - len(machine_summary)))
        machine_summary = np.concatenate([machine_summary, zero_padding])

    f_scores = []
    prec_arr = []
    rec_arr = []

    for user_idx in range(n_users):
        gt_summary = user_summary[user_idx,:]
        overlap_duration = (machine_summary * gt_summary).sum()
        precision = overlap_duration / (machine_summary.sum() + 1e-8)
        recall = overlap_duration / (gt_summary.sum() + 1e-8)
        if precision == 0 and recall == 0:
            f_score = 0.
        else:
            f_score = (2 * precision * recall) / (precision + recall)
        f_scores.append(f_score)
        prec_arr.append(precision)
        rec_arr.append(recall)

    if eval_metric == 'avg':
        final_f_score = np.mean(f_scores)
        final_prec = np.mean(prec_arr)
        final_rec = np.mean(rec_arr)
    elif eval_metric == 'max':
        final_f_score = np.max(f_scores)
        max_idx = np.argmax(f_scores)
        final_prec = prec_arr[max_idx]
        final_rec = rec_arr[max_idx]
    
    return final_f_score, final_prec, final_rec


### Training Hyper Parameters

In [10]:
class HParameters:
    
    def __init__(self, args):
        
        self.verbose = args['verbose']
        self.use_cuda = args['use_cuda']
        self.cuda_device = args['cuda_device']
        self.max_summary_length = args['max_summary_length']

        self.l2_req = 0.00001
        self.lr_epochs = [0]
        self.lr = [0.00005]
        self.epochs_max = 300
        self.train_batch_size = 1

        self.dataset=args['dataset']
        self.results_path = args['results_path']
        self.num_splits = args['num_splits']
        self.split_file = args['split_file']
        self.train_percent = args['train_percent']
        
        if 'model_path' in args:
            self.model_path = args['model_path']
        else:
            self.model_path = None
        return


    def create_split(self):
        print("Loading dataset from {}".format(self.dataset))
        
        with h5py.File(self.dataset, 'r') as dataset:
            keys = dataset.keys()
            num_videos = len(keys)
            num_train = int(math.ceil(num_videos * self.train_percent))
            num_test = num_videos - num_train

            print("Split breakdown: # total videos {}. # train videos {}. # test videos {}".format(num_videos, num_train, num_test))
            splits = []

            for split_idx in range(self.num_splits):
                train_keys, test_keys = split_random(keys, num_videos, num_train)
                splits.append({
                    'train_keys': train_keys,
                    'test_keys': test_keys,
                    })

            # saveto = osp.join(self.split_file)
            write_json(splits, self.split_file)
            print("Splits saved to {}".format(self.split_file))

        
    def __str__(self):
        vars = [attr for attr in dir(self) if not callable(getattr(self,attr)) and not (attr.startswith("__") or attr.startswith("_"))]

        info_str = ''
        for i, var in enumerate(vars):
            val = getattr(self, var)
            if isinstance(val, Variable):
                val = val.data.cpu().numpy().tolist()[0]
            info_str += '['+str(i)+'] '+var+': '+str(val)+'\n'

        return info_str
    
    
#     def load_from_args(self, args):
#         for key in args:
#             val = args[key]
#             if val is not None:
#                 if hasattr(self, key) and isinstance(getattr(self, key), list):
#                     val = val.split()

#                 setattr(self, key, val)

#     def get_dataset_by_name(self, dataset_name):
#         for d in self.datasets:
#             if dataset_name in d:
#                 return [d]
#         return None

    


### Trainer

In [41]:
class Trainer:
    def __init__(self, hps: HParameters):
        self.hps = hps
        self.model = VASNet()
        self.verbose = True
        self.criterion = nn.MSELoss()
        self.show_every = 1
        

        
    def init_model(self):
        if self.hps.model_path:
            self.model.load_state_dict(torch.load(self.hps.model_path, map_location=lambda storage, loc: storage))
        else:
            self.model.eval()
            self.model.apply(weights_init)
        
        
    def train(self, train_keys):
        losses=[]
        for i, key in enumerate(train_keys):
            with h5py.File(self.hps.dataset) as d:
                seq= d[key]['features'][...]
                target = d[key]['gt_score'][...]
                target = target.astype(float)
                
            # seq = dataset['features'][...]
            seq = torch.from_numpy(seq).unsqueeze(0)
            # target = dataset['gtscore'][...]
            target = torch.from_numpy(target).unsqueeze(0)

            # Min-Max Normalize frame scores
            target -= target.min()
            target /= target.max()
            

            if self.hps.use_cuda:
                seq, target = seq.float().cuda(), target.float().cuda()

            seq_len = seq.shape[1]
            y, _ = self.model(seq,seq_len)
            # print(y)
            loss_att = 0

            loss = self.criterion(y, target.float())
            loss = loss + loss_att
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            losses.append(float(loss))
            
        return np.mean(np.array(losses))

    def video_fscore(self, machine_summary_activations, test_keys, metric='tvsum', att_vecs=None):
        eval_metric = 'avg' if metric == 'tvsum' else 'max'

        # if results_filename is not None:
        #     h5_res = h5py.File(results_filename, 'w')

        fms = []
        video_scores = []
        for key_idx, key in enumerate(test_keys):
            
            probs = machine_summary_activations[key]


            with h5py.File(self.hps.dataset,'r') as d:
                cps = d[key]['change_points'][...]
                num_frames = d[key]['n_frames'][()]
                nfps = d[key]['n_frame_per_seg'][...].tolist()
                positions = d[key]['picks'][...]
                user_summary = d[key]['user_summary'][...]

            machine_summary = generate_summary(probs, cps, num_frames, nfps, positions)
            fm, _, _ = evaluate_summary(machine_summary, user_summary, eval_metric)
            fms.append(fm)

            # Reporting & logging
            video_scores.append([key_idx + 1, key, "{:.1%}".format(fm)])
            
        mean_fm = np.mean(fms)
        
        return mean_fm, video_scores

    def validate(self, test_keys):
        self.model.eval()
        summary = {}
        att_vecs = {}
        with torch.no_grad():
            for i, key in enumerate(test_keys):
                with h5py.File(self.hps.dataset) as d:
                    seq = d[key]['features'][...]
                    
                seq = torch.from_numpy(seq).unsqueeze(0)

                if self.hps.use_cuda:
                    seq = seq.float().cuda()

                y, att_vec = self.model(seq, seq.shape[1])
                summary[key] = y[0].detach().cpu().numpy()
                att_vecs[key] = att_vec.detach().cpu().numpy()

        f_score, video_scores = self.video_fscore(summary, test_keys, att_vecs=att_vecs)
        return f_score, video_scores
        
    def run(self):
        print("Initializing VASNet model and optimizer...")
        self.init_model()
        self.model.train()

        if self.hps.use_cuda:
            self.criterion = self.criterion.cuda()

        parameters = filter(lambda p: p.requires_grad, self.model.parameters())
        self.optimizer = torch.optim.Adam(parameters, lr=self.hps.lr[0], weight_decay=self.hps.l2_req)
        
        lr = self.hps.lr[0]
        
        f = open(hps.split_file)
        splits = json.load(f)
        n_folds = len(splits)
        
        print("Starting training...")
        for split in splits:
            max_val_fscore = 0
            max_val_fscore_epoch = 0
            train_keys = split['train_keys']
            test_keys = split['test_keys']

            epoch_losses=[]
            for epoch in range(self.hps.epochs_max):

                print("Epoch: {0:6}".format(str(epoch)+"/"+str(self.hps.epochs_max)), end='')
                self.model.train()

                random.shuffle(train_keys)   
                loss = self.train(train_keys)
                epoch_losses.append(loss)
                
                
                # Evaluate test dataset
                val_fscore, video_scores = self.validate(test_keys)
                if max_val_fscore < val_fscore:
                    max_val_fscore = val_fscore
                    max_val_fscore_epoch = epoch
                
                if epoch%self.show_every==0:
                    print(f'Epoch:{epoch}, Loss:{loss}')

            # avg_loss = np.array(epoch_losses)
            print("   Train loss: {0:.05f}".format(np.mean(np.array(epoch_losses))), end='')
            # print('   Test F-score avg/max: {0:0.5}/{1:0.5}'.format(val_fscore, max_val_fscore))

            if self.verbose:
                video_scores = [["No", "Video", "F-score"]] + video_scores
                print_table(video_scores, cell_width=[3,40,8])

        return max_val_fscore, max_val_fscore_epoch
    
    def save_model(self, name):
        # Save model weights
        filename = name+'_'+str(epoch)+'_'+splitn+'.pth.tar'
        torch.save(self.model.state_dict(), os.path.join('models', filename))
        

**Train**

In [16]:
args={
    'results_path':'training_results.txt',
    'num_splits':5,
    'split_file':'splits/test_split1.json',
    'dataset': '../../Preprocessing/extracted_features/normal/TVSum.h5',
    'train_percent':0.8,
    'verbose':True,
    'use_cuda' : False,
    'cuda_device': None,
    'max_summary_length': 0.15
    
}

In [37]:
hps = HParameters(args)
# hps.load_from_args(args.__dict__)
hps.create_split()

Loading dataset from ../../Preprocessing/extracted_features/normal/TVSum.h5
Split breakdown: # total videos 50. # train videos 40. # test videos 10
Splits saved to splits/test_split1.json


In [40]:
trainer = Trainer(hps)
trainer.run()

Initializing VASNet model and optimizer...
Starting training...
Epoch: 0/300 Epoch:0, Loss:0.2202379021793604
Epoch: 1/300 Epoch:1, Loss:0.20131014492362737
Epoch: 2/300 Epoch:2, Loss:0.1933257630094886
Epoch: 3/300 Epoch:3, Loss:0.17167237289249898
Epoch: 4/300 Epoch:4, Loss:0.16181218046694995
Epoch: 5/300 Epoch:5, Loss:0.1554220153018832
Epoch: 6/300 Epoch:6, Loss:0.14790790136903526
Epoch: 7/300 Epoch:7, Loss:0.14590134508907796
Epoch: 8/300 Epoch:8, Loss:0.13863646406680347
Epoch: 9/300 Epoch:9, Loss:0.1390377800911665
Epoch: 10/300Epoch:10, Loss:0.13635272160172462
Epoch: 11/300Epoch:11, Loss:0.13643947867676615
Epoch: 12/300Epoch:12, Loss:0.13478460973128675
Epoch: 13/300Epoch:13, Loss:0.13450990822166203
Epoch: 14/300Epoch:14, Loss:0.13215046357363464
Epoch: 15/300Epoch:15, Loss:0.13190546222031116
Epoch: 16/300Epoch:16, Loss:0.1320596646517515
Epoch: 17/300Epoch:17, Loss:0.13363236943259835
Epoch: 18/300Epoch:18, Loss:0.1306498414836824
Epoch: 19/300Epoch:19, Loss:0.1333784574

KeyboardInterrupt: 

In [None]:
trainer.save_model()

---------------------------