In [None]:
# Train and evaluate the model.
# Every relative path used below (e.g. 'data') resolves against the kernel's
# working directory, so display it first as a sanity check.
import os

working_dir = os.getcwd()
working_dir

In [None]:
# Make sure the source code auto reloads into the kernel:
# `%autoreload 2` re-imports all modules (e.g. train_and_eval, Utils.logger)
# before each cell runs, so edits to the .py files are picked up without a restart.
%load_ext autoreload
%autoreload 2

In [None]:
# Help prevent shared memory / "too many open files" errors when tensors are
# shared between worker processes.
# NOTE: a `!ulimit -n ...` shell escape only changes the limit of a throw-away
# child shell, never of this kernel process, so the limit is raised in-process here.
import resource
import warnings

import torch.multiprocessing

TARGET_OPEN_FILES_LIMIT = 500000

soft_limit, hard_limit = resource.getrlimit(resource.RLIMIT_NOFILE)
# An unprivileged process may only raise its soft limit up to its hard limit.
if hard_limit == resource.RLIM_INFINITY:
    wanted_soft_limit = TARGET_OPEN_FILES_LIMIT
else:
    wanted_soft_limit = min(TARGET_OPEN_FILES_LIMIT, hard_limit)
# Only ever raise the limit; never lower an already larger (or unlimited) one.
if soft_limit != resource.RLIM_INFINITY and soft_limit < wanted_soft_limit:
    try:
        resource.setrlimit(resource.RLIMIT_NOFILE, (wanted_soft_limit, hard_limit))
    except (ValueError, OSError) as err:
        # Best effort: some platforms cap the limit below the hard limit (e.g. macOS).
        warnings.warn(f'Could not raise the open-file limit to {wanted_soft_limit}: {err}')

# Share tensors between processes via the 'file_system' strategy
# (see the torch.multiprocessing "sharing strategies" documentation).
torch.multiprocessing.set_sharing_strategy('file_system')

In [None]:
import random
import pickle

from Utils.logger import logger
from train_and_eval import test_model
from train_and_eval import train_model
from train_and_eval import create_config
from train_and_eval import visualize_model
from DatasetCreation.helperFunctions import remove_hidden_dir

In [None]:
# Initial set-up: where the dataset lives, which vertical to work on,
# and which attributes the model should extract.
data_path = 'data'
vertical = 'movie'
attributes = [
    'title',
    'director',
    'genre',
    'mpaa_rating',
]

# Number of passes over the training data
num_train_epochs = 10

In [None]:
# List the web-sites available for the chosen vertical.
vertical_dir = os.path.join(data_path, vertical)
site_dirs = remove_hidden_dir(os.listdir(vertical_dir))
# Keep only the part of each directory name before the first '(' —
# presumably dirs look like '<site>(<count>)'; confirm against the data layout.
websites = [site_dir.partition('(')[0] for site_dir in site_dirs]
logger.info(f'There are: {len(websites)} vertical web-sites available:\n{websites}')

In [None]:
# Set training, validation and testing sites.
# A larger training set that also included 'movie-iheartmovies' and 'movie-amctv'
# was considered; those two sites are currently left out.
train_websites = ['movie-yahoo', 'movie-msn', 'movie-rottentomatoes', 'movie-allmovie', 'movie-hollywood']
val_websites = ['movie-imdb', 'movie-metacritic']
test_websites = ['movie-boxofficemojo']

# Fail fast on typos and on leakage between splits: every configured site must be
# one of the `websites` found on disk above, and no site may be in more than one split.
split_websites = train_websites + val_websites + test_websites
duplicated_websites = sorted({site for site in split_websites if split_websites.count(site) > 1})
assert not duplicated_websites, f'Websites used in more than one split: {duplicated_websites}'
unknown_websites = sorted(set(split_websites) - set(websites))
assert not unknown_websites, f'Websites not found under {os.path.join(data_path, vertical)}: {unknown_websites}'

In [None]:
# Build the model config from the website splits and the attributes to extract
config = create_config(
    train_websites=train_websites,
    val_websites=val_websites,
    test_websites=test_websites,
    attributes=attributes,
)

In [None]:
# Move an existing ./data/weights.ckpt out of the way before training, so a previous
# run's weights are kept (renamed with a timestamp suffix) rather than lost.
# NOTE(review): presumably train_model writes to this exact path — confirm in train_and_eval.
import time

old_weights_file_name = os.path.join('data', 'weights.ckpt')

if not os.path.isfile(old_weights_file_name):
    logger.info(f'The previous model weights file is not present, safe to train!')
else:
    backup_stamp = time.time()
    new_weights_file_name = os.path.join('data', f'weights_{backup_stamp}.ckpt')
    logger.warning(f'The previous SimpModel weights file already exists, renaming to: {new_weights_file_name}')
    os.rename(old_weights_file_name, new_weights_file_name)

In [None]:
# Train the model for `num_train_epochs` epochs using the config built above;
# the returned model is used by the visualisation and evaluation cells below.
logger.info(f'Starting model training')
model = train_model(config, num_train_epochs)  # NOTE(review): presumably also saves ./data/weights.ckpt — confirm
logger.info(f'SimpDOM model training is done!')

In [None]:
# Visualise the trained model (what is drawn is defined by train_and_eval.visualize_model).
logger.info(f'Visualising the model')
visualize_model(model)

In [None]:
# Evaluate the freshly trained model.
# NOTE(review): presumably this runs on the config's test_websites — confirm in train_and_eval.test_model.
logger.info(f'Training is finished, starting predicting')
avg_pr_dict = test_model(config, model)
# The name suggests averaged precision/recall per attribute — TODO confirm the dict's structure.
logger.info(f'Test predictions result: {avg_pr_dict}')

In [None]:
# Check the accuracy on the data used for training, i.e. the movie sites in
# `train_websites` (the previously hard-coded 'auto-aol'/'auto-yahoo' sites belong
# to a different vertical and are not part of this config). Comparing this with the
# test-set result above can indicate overfitting. The result is stored under its own
# name so the test-set `avg_pr_dict` is not overwritten.
# NOTE(review): unlike the test-set cell, `model` is not passed here — presumably
# test_model then loads the saved weights; confirm in train_and_eval.test_model.
train_avg_pr_dict = test_model(config, test_websites=train_websites)
logger.info(f'Training websites predictions result: {train_avg_pr_dict}')