In [1]:
def _train_model(train_data, valid_data, test_data, saved_data_file_path=None, checkpoint_path=None):
    """
    Build the model, data loaders, optimiser, scheduler and losses, then train.

    The names ``model`` (a callable building the network), ``vocab``, ``args`` and
    ``logger`` are not parameters; they are read from the enclosing/global scope.

    :param train_data: Raw training data, forwarded to ``_load_and_cache_data``
    :param valid_data: Raw validation data, forwarded to ``_load_and_cache_data``
    :param test_data: Raw test data, forwarded to ``_load_and_cache_data``
    :param saved_data_file_path: Forwarded to ``_load_and_cache_data``
    :param checkpoint_path: Forwarded to ``Trainer`` as its checkpoint location
    :return: ``(best_model, scores)`` exactly as returned by ``Trainer.train``
    :raises NotImplementedError: if ``args.optimiser`` is not one of
        adagrad / adam / adamw / sgd / adadelta (case-insensitive)
    """
    # Bind the instance to a *different* local name. Writing `model = model(...)`
    # makes `model` local to this function, so the call on the right-hand side
    # raises UnboundLocalError before the outer `model` callable is ever reached.
    network = model(vocab, args)
    network.to(vocab.device)

    train_dataset, valid_dataset, test_dataset = _load_and_cache_data(train_data, valid_data, test_data,
                                                                      vocab, args, logger, saved_data_file_path)

    # One log line per split; the emitted messages are identical to the
    # previous three hand-written statements.
    for split_name, dataset in (("train", train_dataset),
                                ("valid", valid_dataset),
                                ("test", test_dataset)):
        label_summary = ", ".join("Level_{} with {} labels".format(level, len(dataset.labels[level]))
                                  for level in range(vocab.n_level()))
        logger.info("{} instances with {} tokens, {} in the {} dataset"
                    .format(dataset.size, dataset.n_total_tokens, label_summary, split_name))

    train_dataloader = TextDataLoader(dataset=train_dataset, vocab=vocab, batch_size=args.batch_size)
    valid_dataloader = TextDataLoader(dataset=valid_dataset, vocab=vocab, batch_size=args.batch_size)
    test_dataloader = TextDataLoader(dataset=test_dataset, vocab=vocab, batch_size=args.batch_size)

    # Only parameters that require gradients are handed to the optimiser.
    trainable_params = [p for p in network.parameters() if p.requires_grad]
    optimiser_name = args.optimiser.lower()
    betas = (0.9, 0.999)
    if optimiser_name == "adagrad":
        optimiser = optim.Adagrad(trainable_params, lr=args.lr)
    elif optimiser_name == "adam":
        optimiser = optim.Adam(trainable_params,
                               lr=args.lr, betas=betas, weight_decay=args.weight_decay)
    elif optimiser_name == "adamw":
        optimiser = AdamW(trainable_params,
                          lr=args.lr, betas=betas, weight_decay=args.weight_decay)
    elif optimiser_name == "sgd":
        optimiser = optim.SGD(trainable_params,
                              lr=args.lr, momentum=0.9, weight_decay=args.weight_decay)
    elif optimiser_name == "adadelta":
        optimiser = optim.Adadelta(trainable_params,
                                   lr=args.lr, weight_decay=args.weight_decay)
    else:
        raise NotImplementedError("Unsupported optimiser: {}".format(args.optimiser))

    lr_plateau = None
    if args.use_lr_scheduler:
        # mode="max": the scheduler is stepped with a score where higher is better.
        lr_plateau = optim.lr_scheduler.ReduceLROnPlateau(optimiser,
                                                          mode="max",
                                                          factor=args.lr_scheduler_factor,
                                                          patience=args.lr_scheduler_patience,
                                                          min_lr=0.0001)

    # One loss function per label level.
    if args.multilabel:
        criterions = [nn.BCEWithLogitsLoss() for _ in range(vocab.n_level())]
    else:
        criterions = [nn.CrossEntropyLoss() for _ in range(vocab.n_level())]

    # The Trainer defined in this file takes no `logger`/`args` keyword
    # arguments, so they are not passed (passing them raises TypeError).
    trainer = Trainer(model=network,
                      train_dataloader=train_dataloader,
                      valid_dataloader=valid_dataloader,
                      test_dataloader=test_dataloader,
                      criterions=criterions,
                      optimiser=optimiser,
                      lr_scheduler=lr_plateau,
                      vocab=vocab,
                      checkpoint_path=checkpoint_path)
    best_model, scores = trainer.train(n_epoch=args.n_epoch, patience=args.patience)

    # NOTE(review): this evaluator is never used afterwards; it is kept only
    # because side effects of its constructor cannot be ruled out from here.
    evaluator = Evaluator(model=best_model,
                          vocab=vocab,
                          criterions=criterions,
                          n_training_labels=get_n_training_labels(train_dataloader))

    del network, lr_plateau, optimiser, evaluator, trainer, criterions
    return best_model, scores  # either on valid or test

ModuleNotFoundError: No module named 'src'

In [None]:
class Trainer:
    """
    Epoch-based training loop with validation, checkpointing and early stopping.

    Besides its constructor arguments the class reads several names from the
    enclosing/global scope (``args``, ``device``, ``save_best_model``,
    ``use_regularisation``, ``penalisation_coeff``, ``multilabel``,
    ``save_results``, ``result_path``, ``best_model_path``, ``main_metric``,
    ``metric_level``, ``ndigits``); they must be defined before use.
    """

    def __init__(self, model: nn.Module,
                 train_dataloader: TextDataLoader,
                 valid_dataloader: TextDataLoader,
                 test_dataloader: TextDataLoader,
                 criterions,
                 optimiser,
                 lr_scheduler,
                 vocab,
                 checkpoint_path,
                 ):
        """
        Store the training components and prepare the output locations.

        :param model: The machine learning model
        :param train_dataloader: Training dataloader
        :param valid_dataloader: Validation dataloader
        :param test_dataloader: Test dataloader
        :param criterions: Loss functions, indexed by label level
        :param optimiser: Optimiser used to update the model (Adam/AdamW ...)
        :param lr_scheduler: Optional scheduler; when not None it is stepped once
            per epoch with the main validation metric
        :param vocab: Vocabulary
        :param checkpoint_path: Path where the latest checkpoint is saved
        """
        self.model = model
        self.train_dataloader = train_dataloader
        self.valid_dataloader = valid_dataloader
        self.test_dataloader = test_dataloader
        self.criterions = criterions
        self.optimiser = optimiser
        self.lr_scheduler = lr_scheduler
        self.vocab = vocab
        self.args = args
        self.save_best_model = save_best_model
        self.use_regularisation = use_regularisation
        self.penalisation_coeff = penalisation_coeff

        self.multilabel = multilabel
        self.save_results = save_results
        self.checkpoint_path = checkpoint_path
        self.n_training_labels = get_n_training_labels(train_dataloader)

        # Always define the path attributes so they exist even when the
        # corresponding feature is switched off.
        self.saved_result_path = None
        if self.save_results:
            self.saved_result_path = result_path
        self.saved_last_model_path = None
        self.best_model_path = None
        if self.save_best_model:
            self.best_model_path = best_model_path
            # Create the target directory only when a best model will be
            # written. A bare filename has an empty dirname, which
            # os.makedirs rejects, so that case is skipped.
            saved_model_dir = os.path.dirname(self.best_model_path)
            if saved_model_dir:
                os.makedirs(saved_model_dir, exist_ok=True)

        self.main_metric = main_metric
        self.metric_level = "level_{}".format(metric_level)
        if metric_level < 0:
            # A negative level selects the scores averaged over all levels.
            self.metric_level = "average"

        self.start_epoch = 0
        self.best_val = None
        self.saved_test_scores = None
        self.best_epoch_num = 1

    def train_single_epoch(self, index):
        """
        This is to train a single epoch
        :param index: epoch index
        :return: scores, an OrderedDict with one "level_{i}" entry per output
            level plus an "average" entry; every entry has a "loss" value
        """
        self.model.train()

        if bool(self.args.shuffle_data):
            self.train_dataloader.dataset.shuffle_data()
        losses = []
        true_labels = [[] for _ in range(self.vocab.n_level())]
        pred_probs = [[] for _ in range(self.vocab.n_level())]
        ids = []
        all_loss_list = []
        progress_bar = tqdm(self.train_dataloader, unit="batches", desc="Training at epoch #{}".format(index))
        progress_bar.clear()
        self.optimiser.zero_grad()

        for text_batch, label_batch, length_batch, id_batch in progress_bar:
            text_batch = text_batch.to(device)
            label_batch = [level_labels.to(device) for level_labels in label_batch]

            if isinstance(length_batch, list):
                length_batch = [lengths.to(device) for lengths in length_batch]
            else:
                length_batch = length_batch.to(device)

            output, attn_weights = self.model(text_batch, length_batch)
            loss_list = []

            for level in range(len(output)):
                level_labels = label_batch[level]
                true_labels[level].extend(level_labels.cpu().numpy())
                loss_list.append(self.criterions[level](output[level], level_labels))

            # One running list of batch losses per level, created on the first batch.
            if not all_loss_list:
                all_loss_list = [[] for _ in loss_list]
            for level, level_loss in enumerate(loss_list):
                all_loss_list[level].append(level_loss.item())

            ids.extend(id_batch)
            for level in range(len(output)):
                # Turn the raw model outputs into probabilities for the metrics.
                if self.multilabel:
                    probs = torch.sigmoid(output[level]).detach().cpu().numpy()
                    pred_probs[level].extend(probs)
                else:
                    probs = torch.softmax(output[level], 1).detach().cpu().numpy()
                    pred_probs[level].extend(probs.tolist())
            loss = get_loss(loss_list, self.n_training_labels)
            loss.backward()
            losses.append(loss.item())

            self.optimiser.step()
            self.optimiser.zero_grad()

        scores = OrderedDict()
        # Use the number of levels recorded during the loop rather than the
        # loop variable `output`, which is undefined if the loader was empty.
        for level in range(len(all_loss_list)):
            if self.args.save_results_on_train:
                scores["level_{}".format(level)] = calculate_eval_metrics(ids, true_labels[level],
                                                                          pred_probs[level], self.multilabel)
            else:
                scores["level_{}".format(level)] = {}
            # NOTE(review): per-level losses are stored negated while the
            # average loss below is not — confirm this is intended.
            scores["level_{}".format(level)]["loss"] = -np.mean(all_loss_list[level]).item()

        scores["average"] = average_scores(scores)
        scores["average"]["loss"] = np.mean(losses).item()
        progress_bar.refresh(True)
        progress_bar.clear(True)
        progress_bar.close()

        return scores

    def calculate_penalisation(self, attn_weights, batch_size):
        """
        This is the penalisation for the self attention: the l2 matrix norm of
        (A @ A^T - I), where A is ``attn_weights`` and A^T swaps dims 1 and 2.
        :param attn_weights: attention weights; indexed here as a 3-D tensor
        :param batch_size: size used to expand the identity matrix along dim 0
        :return: the result of ``AttentionLayer.l2_matrix_norm``
        """
        transposed_attn_weights = attn_weights.transpose(1, 2)
        identity = torch.eye(attn_weights.size(1))
        identity = Variable(identity.unsqueeze(0).expand(batch_size, attn_weights.size(1), attn_weights.size(1))).to(
            device)
        penal = AttentionLayer.l2_matrix_norm(attn_weights @ transposed_attn_weights - identity)
        return penal

    def save_checkpoint(self, state, is_best):
        """
        Save ``state`` to the checkpoint path and, when ``is_best`` is true,
        copy that file to the best-model path.
        :param state: object passed to ``torch.save``
        :param is_best: whether to also copy the checkpoint as the best model
        """
        torch.save(state, self.checkpoint_path)
        if is_best:
            shutil.copyfile(self.checkpoint_path, self.best_model_path)

    @staticmethod
    def format_number(number):
        """
        Return the absolute value of ``number`` rounded to ``ndigits`` digits
        (``ndigits`` is read from the enclosing/global scope).
        """
        return abs(round(number, ndigits=ndigits))

    def train(self,
              n_epoch: int = 100,
              patience: int = 5):
        """
        Train for up to ``n_epoch`` epochs, checkpointing after every epoch.

        :param n_epoch: Index of the last epoch to run
        :param patience: Training stops once the main validation metric has
            failed to improve for more than ``patience`` consecutive epochs;
            a value of 0 or less disables early stopping
        :return: ``(model, best_valid_scores)``; when a best model was saved,
            its weights are loaded into the model before returning
        """
        if self.args.resume_training:
            if os.path.isfile(self.checkpoint_path):
                # Restore the full training state from the latest checkpoint.
                checkpoint = torch.load(self.checkpoint_path)
                self.start_epoch = checkpoint['epoch']
                self.best_val = checkpoint['best_val']
                self.model.load_state_dict(checkpoint['state_dict'])
                self.optimiser.load_state_dict(checkpoint['optimiser'])
                self.saved_test_scores = checkpoint['test_scores']
                self.best_epoch_num = checkpoint['best_epoch_num']
                if self.lr_scheduler is not None:
                    self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])

        best_valid_scores = self.best_val
        saved_test_scores = self.saved_test_scores
        saved_train_scores = None
        check_to_stop = 0  # consecutive epochs without improvement
        best_epoch_num = self.best_epoch_num
        evaluator = Evaluator(self.model, self.vocab, self.criterions, self.n_training_labels)
        for e in range(self.start_epoch + 1, n_epoch + 1):
            train_scores = self.train_single_epoch(e)
            valid_scores = evaluator.evaluate(self.valid_dataloader)
            test_scores = evaluator.evaluate(self.test_dataloader)

            current_metric = valid_scores[self.metric_level][self.main_metric]
            if self.lr_scheduler is not None:
                self.lr_scheduler.step(current_metric)

            is_best = False
            if best_valid_scores is None or \
                    best_valid_scores[self.metric_level][self.main_metric] < current_metric:
                # Improvement: remember this epoch and reset the patience counter.
                check_to_stop = 0
                best_valid_scores = valid_scores
                saved_test_scores = test_scores
                saved_train_scores = train_scores
                best_epoch_num = e
                if self.save_best_model:
                    is_best = True
            else:
                # Without this increment the early-stopping check below could
                # never trigger.
                check_to_stop += 1

            lr_scheduler_state_dict = None
            if self.lr_scheduler is not None:
                lr_scheduler_state_dict = self.lr_scheduler.state_dict()

            self.save_checkpoint({
                'epoch': e,
                'state_dict': self.model.state_dict(),
                'best_val': best_valid_scores,
                'test_scores': saved_test_scores,
                'optimiser': self.optimiser.state_dict(),
                'best_epoch_num': best_epoch_num,
                'lr_scheduler': lr_scheduler_state_dict
            }, is_best)

            if check_to_stop > patience > 0:
                break

        if self.save_results:
            import pickle
            results = {"train": saved_train_scores, "valid": best_valid_scores, "test": saved_test_scores,
                       "params": self.args, "index2label": self.vocab.index2label}

            with open(self.saved_result_path, 'wb') as f:
                pickle.dump(results, f, pickle.HIGHEST_PROTOCOL)

        if self.save_best_model and os.path.isfile(self.best_model_path):
            best_model = torch.load(self.best_model_path)
            self.model.load_state_dict(best_model['state_dict'])
        return self.model, best_valid_scores