## Load the config file

In [1]:
from graph4nlp.pytorch.modules.config import get_basic_args
from graph4nlp.pytorch.modules.utils.config_utils import update_values, get_yaml_config
import json
def get_args():
    """Assemble and return the experiment configuration dict.

    The dataset YAML named below selects the graph construction, graph
    embedding and decoder; graph4nlp's ``get_basic_args`` builds the matching
    argument template, which is then updated in place by ``update_values``
    from the YAML values and the experiment-level settings below
    (precedence between the two sources is decided by ``update_values`` --
    presumably later entries win; confirm).
    """
    experiment_settings = dict(
        dataset_yaml="./exp_2_emb_strategy_config.yaml",
        learning_rate=1e-3,
        gpuid=-1,
        seed=123,
        init_weight=0.08,
        weight_decay=0,
        max_epochs=200,
        min_freq=1,
        grad_clip=5,
        batch_size=80,
        share_vocab=True,
        pretrained_word_emb_name="6B",
        checkpoint_save_path="./checkpoint_save",
        beam_size=1,
    )
    yaml_settings = get_yaml_config(experiment_settings["dataset_yaml"])
    merged = get_basic_args(
        graph_construction_name=yaml_settings["graph_construction_name"],
        graph_embedding_name=yaml_settings["graph_embedding_name"],
        decoder_name=yaml_settings["decoder_name"],
    )
    update_values(to_args=merged, from_args_list=[yaml_settings, experiment_settings])
    return merged

# Build the merged experiment configuration and print it for inspection.
from pprint import pprint

cfg_g2t = get_args()
pprint(cfg_g2t)

# Mawps.train() writes the config and per-evaluation metrics to this JSON file.
experiment_result_file = "exp_2_emb_strategy_results.json"

{'batch_size': 80,
 'beam_size': 1,
 'checkpoint_save_path': './checkpoint_save',
 'dataset_yaml': './exp_2_emb_strategy_config.yaml',
 'decoder_args': {'rnn_decoder_private': {'max_decoder_step': 35,
                                          'max_tree_depth': 8,
                                          'use_sibling': False},
                  'rnn_decoder_share': {'attention_type': 'uniform',
                                        'dropout': 0.3,
                                        'fuse_strategy': 'concatenate',
                                        'graph_pooling_strategy': None,
                                        'hidden_size': 300,
                                        'input_size': 300,
                                        'rnn_emb_input_size': 300,
                                        'rnn_type': 'lstm',
                                        'teacher_forcing_rate': 1.0,
                                        'use_copy': True,
                             

In [2]:
import copy
import random
import numpy as np
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm.autonotebook import tqdm

from graph4nlp.pytorch.datasets.mawps import MawpsDatasetForTree
from graph4nlp.pytorch.models.graph2tree import Graph2Tree
from graph4nlp.pytorch.modules.utils.tree_utils import Tree

from utils import convert_to_string, compute_tree_accuracy, prepare_ext_vocab

  from tqdm.autonotebook import tqdm


In [3]:
class Mawps:
    """Train and evaluate a Graph2Tree model on the MAWPS dataset.

    ``opt`` is the merged configuration dict (as produced by ``get_args``).
    Construction seeds the RNGs, then builds the data loaders, the model and
    the Adam optimizer.
    """

    def __init__(self, opt=None):
        super(Mawps, self).__init__()
        if opt is None:
            raise ValueError("Mawps requires a configuration dict (opt)")
        self.opt = opt

        # Seed Python, NumPy and PyTorch RNGs for reproducible runs.
        seed = self.opt["seed"]
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

        # gpuid == -1 selects the CPU; any other value is used as a CUDA device index.
        if self.opt["gpuid"] == -1:
            self.device = torch.device("cpu")
        else:
            self.device = torch.device("cuda:{}".format(self.opt["gpuid"]))

        graph_share_args = self.opt["graph_construction_args"]["graph_construction_share"]
        self.use_copy = self.opt["decoder_args"]["rnn_decoder_share"]["use_copy"]
        self.use_share_vocab = graph_share_args["share_vocab"]
        self.data_dir = graph_share_args["root_dir"]

        self._build_dataloader()
        self._build_model()
        self._build_optimizer()

    def _build_dataloader(self):
        """Build MawpsDatasetForTree from the config and wrap its splits in DataLoaders.

        Also exposes the dataset's vocabularies as ``src_vocab``, ``tgt_vocab``
        and (only when vocab sharing is enabled) ``share_vocab``.
        """
        construction_args = self.opt["graph_construction_args"]
        graph_share_args = construction_args["graph_construction_share"]
        graph_private_args = construction_args["graph_construction_private"]
        enc_emb_size = self.opt["graph_initialization_args"]["input_size"]
        para_dic = {
            "root_dir": self.data_dir,
            "word_emb_size": enc_emb_size,
            "topology_subdir": graph_share_args["topology_subdir"],
            "edge_strategy": graph_private_args["edge_strategy"],
            "graph_name": graph_share_args["graph_name"],
            "share_vocab": self.use_share_vocab,
            "enc_emb_size": enc_emb_size,
            "dec_emb_size": self.opt["decoder_args"]["rnn_decoder_share"]["input_size"],
            "dynamic_init_graph_name": graph_private_args.get("dynamic_init_graph_name", None),
            "min_word_vocab_freq": self.opt["min_freq"],
            "pretrained_word_emb_name": self.opt["pretrained_word_emb_name"],
        }

        dataset = MawpsDatasetForTree(**para_dic)

        self.train_data_loader = DataLoader(
            dataset.train,
            batch_size=self.opt["batch_size"],
            shuffle=True,
            num_workers=0,
            collate_fn=dataset.collate_fn,
        )
        # Evaluation loaders use batch_size=1: eval() decodes and scores one example at a time.
        self.test_data_loader = DataLoader(
            dataset.test, batch_size=1, shuffle=False, num_workers=0, collate_fn=dataset.collate_fn
        )
        self.valid_data_loader = DataLoader(
            dataset.val, batch_size=1, shuffle=False, num_workers=0, collate_fn=dataset.collate_fn
        )
        self.vocab_model = dataset.vocab_model
        self.src_vocab = self.vocab_model.in_word_vocab
        self.tgt_vocab = self.vocab_model.out_word_vocab
        self.share_vocab = self.vocab_model.share_vocab if self.use_share_vocab else None

    def _build_model(self):
        """Create the Graph2Tree encoder-decoder, call its init() with opt["init_weight"],
        and move it to ``self.device``."""
        self.model = Graph2Tree.from_args(self.opt, vocab_model=self.vocab_model)
        self.model.init(self.opt["init_weight"])
        self.model.to(self.device)

    def _build_optimizer(self):
        """Create an Adam optimizer over the model's trainable (requires_grad) parameters."""
        parameters = [p for p in self.model.parameters() if p.requires_grad]
        self.optimizer = optim.Adam(
            parameters, lr=self.opt["learning_rate"], weight_decay=self.opt["weight_decay"]
        )

    def train_epoch(self, epoch):
        """Run one optimisation pass over the training set.

        Args:
            epoch: epoch number, used only for the progress-bar label.

        Returns:
            Mean of the per-batch losses as a detached 0-d tensor.
        """
        total_loss = 0.0
        num_batch = len(self.train_data_loader)
        for data in tqdm(self.train_data_loader, desc=f"Epoch {epoch:02d}", total=num_batch):
            batch_graph = data["graph_data"].to(self.device)
            batch_tree_list = data["dec_tree_batch"]
            self.optimizer.zero_grad()

            oov_dict = None
            if self.use_copy:
                # With copying enabled, targets are re-indexed against ``oov_dict``
                # (the batch vocabulary returned by prepare_ext_vocab) instead of
                # using the pre-built ``dec_tree_batch``.
                oov_dict = prepare_ext_vocab(batch_graph, self.src_vocab, self.device)
                batch_tree_list = []
                for item in data["original_dec_tree_batch"]:
                    tgt_list = oov_dict.get_symbol_idx_for_list(item.strip().split())
                    batch_tree_list.append(
                        Tree.convert_to_tree(tgt_list, 0, len(tgt_list), oov_dict)
                    )

            loss = self.model(batch_graph, batch_tree_list, oov_dict=oov_dict)
            loss.backward()
            torch.nn.utils.clip_grad_value_(self.model.parameters(), self.opt["grad_clip"])
            self.optimizer.step()
            # Detach before accumulating: summing the raw loss would chain every
            # batch's autograd history onto ``total_loss`` for the whole epoch.
            total_loss += loss.detach()
        return total_loss / num_batch

    def train(self, result_file=None):
        """Train for opt["max_epochs"] epochs, evaluating on test and val every 10 epochs.

        Whenever validation accuracy improves, the current model is saved to
        ``opt["checkpoint_save_path"]`` as "best.pt" (so the checkpoint really
        holds the best-validation weights, not the final ones). At the end the
        config and the per-evaluation metrics are written as JSON.

        Args:
            result_file: path of the JSON results file; defaults to the
                module-level ``experiment_result_file``.

        Returns:
            ``(test_acc, val_acc)`` from the evaluation with the best validation
            accuracy, or ``(-1, -1)`` if no evaluation ran.
        """
        if result_file is None:
            result_file = experiment_result_file
        best_acc = (-1, -1)
        train_data = []
        print("-------------\nStarting training.")
        for epoch in range(1, self.opt["max_epochs"] + 1):
            self.model.train()
            loss_to_print = self.train_epoch(epoch)
            print("epochs = {}, train_loss = {:.3f}".format(epoch, loss_to_print))
            if epoch % 10 == 0:
                test_acc = self.eval(self.model, mode="test")
                val_acc = self.eval(self.model, mode="val")
                train_data.append(
                    dict(train_loss=float(loss_to_print), val_acc=val_acc, test_acc=test_acc)
                )
                if val_acc > best_acc[1]:
                    best_acc = (test_acc, val_acc)
                    # Save now: keeping only a reference to self.model would let
                    # further training overwrite the "best" weights.
                    self.model.save_checkpoint(self.opt["checkpoint_save_path"], "best.pt")
        print("Best Acc: {:.3f}\n".format(best_acc[0]))
        result_data = {"config": self.opt, "train_data": train_data}
        with open(result_file, "w") as f:
            json.dump(result_data, f, indent=4)
        return best_acc

    @staticmethod
    def _balance_parens(candidate, vocab):
        """Return ``candidate`` (a list of token ids) with equal "(" and ")" counts.

        Missing ")" tokens are appended; if there are more ")" than "(" tokens,
        that many tokens are dropped from the end of the sequence.
        """
        symbols = [vocab.idx2symbol[c] for c in candidate]
        diff = symbols.count("(") - symbols.count(")")
        if diff > 0:
            return candidate + [vocab.symbol2idx[")"]] * diff
        if diff < 0:
            return candidate[:diff]
        return candidate

    def eval(self, model, mode="val"):
        """Decode one split with ``model`` and return its tree accuracy.

        Args:
            model: the model to evaluate; it is switched to eval mode.
            mode: "test" evaluates the test split; any other value the validation split.

        Returns:
            The value returned by ``compute_tree_accuracy`` over all decoded examples.
        """
        model.eval()
        reference_list = []
        candidate_list = []
        data_loader = self.test_data_loader if mode == "test" else self.valid_data_loader
        # Decoding needs no gradients; skipping autograd bookkeeping saves memory and time.
        with torch.no_grad():
            for data in tqdm(data_loader, desc="Eval: "):
                batch_original_tree_list = data["original_dec_tree_batch"]
                # The eval loaders use batch_size=1, so there is exactly one example per batch.
                assert len(batch_original_tree_list) == 1
                eval_input_graph = data["graph_data"].to(self.device)
                oov_dict = prepare_ext_vocab(eval_input_graph, self.src_vocab, self.device)
                # Reference ids, candidate ids and paren fixing must share one id space.
                eval_vocab = oov_dict if self.use_copy else self.tgt_vocab
                reference = eval_vocab.get_symbol_idx_for_list(
                    batch_original_tree_list[0].split()
                )

                candidate = model.translate(
                    eval_input_graph,
                    oov_dict=oov_dict,
                    use_beam_search=True,
                    beam_size=self.opt["beam_size"],
                )
                candidate = self._balance_parens([int(c) for c in candidate], eval_vocab)

                reference_list.append(reference)
                candidate_list.append(candidate)
        # NOTE(review): in copy mode each example has its own oov_dict, but accuracy is
        # computed with the last example's vocab; if compute_tree_accuracy maps ids back
        # to symbols, earlier examples' extended ids may be mis-decoded -- confirm.
        eval_acc = compute_tree_accuracy(candidate_list, reference_list, eval_vocab)
        print("{} accuracy = {:.3f}\n".format(mode, eval_acc))
        return eval_acc

In [4]:
!rm -r ./data/processed/*

In [5]:
# Seed RNGs and build the dataset loaders, vocabularies, model and optimizer from the merged config.
a = Mawps(cfg_g2t)



In [6]:
# Train for cfg_g2t["max_epochs"] epochs; returns (test_acc, val_acc) from the best-validation evaluation.
best_acc = a.train()

-------------
Starting training.


Epoch 01:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 1, train_loss = 31.375


Epoch 02:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 2, train_loss = 16.825


Epoch 03:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 3, train_loss = 12.983


Epoch 04:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 4, train_loss = 10.279


Epoch 05:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 5, train_loss = 8.828


Epoch 06:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 6, train_loss = 7.717


Epoch 07:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 7, train_loss = 6.898


Epoch 08:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 8, train_loss = 6.275


Epoch 09:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 9, train_loss = 5.724


Epoch 10:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 10, train_loss = 5.335


Eval:   0%|          | 0/250 [00:00<?, ?it/s]

test accuracy = 0.296



Eval:   0%|          | 0/250 [00:00<?, ?it/s]

val accuracy = 0.324



Epoch 11:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 11, train_loss = 4.997


Epoch 12:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 12, train_loss = 4.528


Epoch 13:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 13, train_loss = 4.351


Epoch 14:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 14, train_loss = 3.984


Epoch 15:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 15, train_loss = 3.635


Epoch 16:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 16, train_loss = 3.539


Epoch 17:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 17, train_loss = 3.659


Epoch 18:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 18, train_loss = 3.536


Epoch 19:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 19, train_loss = 3.322


Epoch 20:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 20, train_loss = 3.167


Eval:   0%|          | 0/250 [00:00<?, ?it/s]

test accuracy = 0.352



Eval:   0%|          | 0/250 [00:00<?, ?it/s]

val accuracy = 0.376



Epoch 21:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 21, train_loss = 3.069


Epoch 22:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 22, train_loss = 2.825


Epoch 23:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 23, train_loss = 2.629


Epoch 24:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 24, train_loss = 2.466


Epoch 25:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 25, train_loss = 2.369


Epoch 26:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 26, train_loss = 2.162


Epoch 27:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 27, train_loss = 1.977


Epoch 28:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 28, train_loss = 1.854


Epoch 29:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 29, train_loss = 1.754


Epoch 30:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 30, train_loss = 1.623


Eval:   0%|          | 0/250 [00:00<?, ?it/s]

test accuracy = 0.428



Eval:   0%|          | 0/250 [00:00<?, ?it/s]

val accuracy = 0.416



Epoch 31:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 31, train_loss = 1.575


Epoch 32:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 32, train_loss = 1.440


Epoch 33:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 33, train_loss = 1.345


Epoch 34:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 34, train_loss = 1.242


Epoch 35:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 35, train_loss = 1.227


Epoch 36:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 36, train_loss = 1.172


Epoch 37:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 37, train_loss = 1.252


Epoch 38:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 38, train_loss = 1.106


Epoch 39:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 39, train_loss = 0.893


Epoch 40:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 40, train_loss = 0.879


Eval:   0%|          | 0/250 [00:00<?, ?it/s]

test accuracy = 0.448



Eval:   0%|          | 0/250 [00:00<?, ?it/s]

val accuracy = 0.452



Epoch 41:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 41, train_loss = 0.801


Epoch 42:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 42, train_loss = 0.753


Epoch 43:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 43, train_loss = 0.671


Epoch 44:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 44, train_loss = 0.640


Epoch 45:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 45, train_loss = 0.625


Epoch 46:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 46, train_loss = 0.567


Epoch 47:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 47, train_loss = 0.520


Epoch 48:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 48, train_loss = 0.497


Epoch 49:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 49, train_loss = 0.458


Epoch 50:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 50, train_loss = 0.389


Eval:   0%|          | 0/250 [00:00<?, ?it/s]

test accuracy = 0.484



Eval:   0%|          | 0/250 [00:00<?, ?it/s]

val accuracy = 0.492



Epoch 51:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 51, train_loss = 0.381


Epoch 52:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 52, train_loss = 0.332


Epoch 53:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 53, train_loss = 0.322


Epoch 54:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 54, train_loss = 0.319


Epoch 55:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 55, train_loss = 0.256


Epoch 56:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 56, train_loss = 0.274


Epoch 57:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 57, train_loss = 0.257


Epoch 58:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 58, train_loss = 0.234


Epoch 59:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 59, train_loss = 0.190


Epoch 60:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 60, train_loss = 0.198


Eval:   0%|          | 0/250 [00:00<?, ?it/s]

test accuracy = 0.512



Eval:   0%|          | 0/250 [00:00<?, ?it/s]

val accuracy = 0.464



Epoch 61:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 61, train_loss = 0.222


Epoch 62:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 62, train_loss = 0.175


Epoch 63:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 63, train_loss = 0.150


Epoch 64:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 64, train_loss = 0.104


Epoch 65:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 65, train_loss = 0.132


Epoch 66:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 66, train_loss = 0.144


Epoch 67:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 67, train_loss = 0.139


Epoch 68:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 68, train_loss = 0.141


Epoch 69:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 69, train_loss = 0.117


Epoch 70:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 70, train_loss = 0.075


Eval:   0%|          | 0/250 [00:00<?, ?it/s]

test accuracy = 0.472



Eval:   0%|          | 0/250 [00:00<?, ?it/s]

val accuracy = 0.448



Epoch 71:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 71, train_loss = 0.072


Epoch 72:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 72, train_loss = 0.034


Epoch 73:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 73, train_loss = -0.009


Epoch 74:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 74, train_loss = 0.016


Epoch 75:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 75, train_loss = -0.008


Epoch 76:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 76, train_loss = -0.008


Epoch 77:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 77, train_loss = 0.074


Epoch 78:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 78, train_loss = 0.055


Epoch 79:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 79, train_loss = 0.065


Epoch 80:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 80, train_loss = 0.056


Eval:   0%|          | 0/250 [00:00<?, ?it/s]

test accuracy = 0.492



Eval:   0%|          | 0/250 [00:00<?, ?it/s]

val accuracy = 0.448



Epoch 81:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 81, train_loss = 0.021


Epoch 82:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 82, train_loss = 0.018


Epoch 83:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 83, train_loss = 0.004


Epoch 84:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 84, train_loss = 0.009


Epoch 85:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 85, train_loss = 0.019


Epoch 86:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 86, train_loss = -0.010


Epoch 87:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 87, train_loss = -0.028


Epoch 88:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 88, train_loss = -0.044


Epoch 89:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 89, train_loss = -0.029


Epoch 90:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 90, train_loss = -0.032


Eval:   0%|          | 0/250 [00:00<?, ?it/s]

test accuracy = 0.484



Eval:   0%|          | 0/250 [00:00<?, ?it/s]

val accuracy = 0.432



Epoch 91:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 91, train_loss = -0.046


Epoch 92:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 92, train_loss = -0.082


Epoch 93:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 93, train_loss = -0.088


Epoch 94:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 94, train_loss = -0.091


Epoch 95:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 95, train_loss = -0.056


Epoch 96:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 96, train_loss = -0.056


Epoch 97:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 97, train_loss = -0.033


Epoch 98:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 98, train_loss = -0.030


Epoch 99:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 99, train_loss = -0.042


Epoch 100:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 100, train_loss = -0.065


Eval:   0%|          | 0/250 [00:00<?, ?it/s]

test accuracy = 0.476



Eval:   0%|          | 0/250 [00:00<?, ?it/s]

val accuracy = 0.452



Epoch 101:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 101, train_loss = -0.088


Epoch 102:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 102, train_loss = -0.045


Epoch 103:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 103, train_loss = -0.080


Epoch 104:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 104, train_loss = -0.096


Epoch 105:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 105, train_loss = -0.090


Epoch 106:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 106, train_loss = -0.123


Epoch 107:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 107, train_loss = -0.078


Epoch 108:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 108, train_loss = -0.070


Epoch 109:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 109, train_loss = -0.107


Epoch 110:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 110, train_loss = -0.108


Eval:   0%|          | 0/250 [00:00<?, ?it/s]

test accuracy = 0.504



Eval:   0%|          | 0/250 [00:00<?, ?it/s]

val accuracy = 0.468



Epoch 111:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 111, train_loss = -0.090


Epoch 112:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 112, train_loss = -0.136


Epoch 113:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 113, train_loss = -0.112


Epoch 114:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 114, train_loss = -0.118


Epoch 115:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 115, train_loss = -0.085


Epoch 116:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 116, train_loss = -0.123


Epoch 117:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 117, train_loss = -0.099


Epoch 118:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 118, train_loss = -0.067


Epoch 119:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 119, train_loss = -0.084


Epoch 120:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 120, train_loss = -0.066


Eval:   0%|          | 0/250 [00:00<?, ?it/s]

test accuracy = 0.476



Eval:   0%|          | 0/250 [00:00<?, ?it/s]

val accuracy = 0.456



Epoch 121:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 121, train_loss = -0.061


Epoch 122:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 122, train_loss = -0.086


Epoch 123:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 123, train_loss = -0.091


Epoch 124:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 124, train_loss = -0.088


Epoch 125:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 125, train_loss = -0.090


Epoch 126:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 126, train_loss = -0.070


Epoch 127:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 127, train_loss = -0.037


Epoch 128:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 128, train_loss = -0.070


Epoch 129:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 129, train_loss = 0.021


Epoch 130:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 130, train_loss = 0.000


Eval:   0%|          | 0/250 [00:00<?, ?it/s]

test accuracy = 0.496



Eval:   0%|          | 0/250 [00:00<?, ?it/s]

val accuracy = 0.480



Epoch 131:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 131, train_loss = -0.019


Epoch 132:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 132, train_loss = -0.061


Epoch 133:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 133, train_loss = -0.078


Epoch 134:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 134, train_loss = -0.091


Epoch 135:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 135, train_loss = -0.100


Epoch 136:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 136, train_loss = -0.104


Epoch 137:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 137, train_loss = -0.002


Epoch 138:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 138, train_loss = 0.008


Epoch 139:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 139, train_loss = -0.048


Epoch 140:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 140, train_loss = -0.065


Eval:   0%|          | 0/250 [00:00<?, ?it/s]

test accuracy = 0.492



Eval:   0%|          | 0/250 [00:00<?, ?it/s]

val accuracy = 0.464



Epoch 141:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 141, train_loss = -0.096


Epoch 142:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 142, train_loss = -0.157


Epoch 143:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 143, train_loss = -0.122


Epoch 144:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 144, train_loss = -0.087


Epoch 145:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 145, train_loss = -0.110


Epoch 146:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 146, train_loss = -0.109


Epoch 147:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 147, train_loss = -0.126


Epoch 148:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 148, train_loss = -0.151


Epoch 149:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 149, train_loss = -0.129


Epoch 150:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 150, train_loss = -0.108


Eval:   0%|          | 0/250 [00:00<?, ?it/s]

test accuracy = 0.480



Eval:   0%|          | 0/250 [00:00<?, ?it/s]

val accuracy = 0.440



Epoch 151:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 151, train_loss = -0.135


Epoch 152:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 152, train_loss = -0.177


Epoch 153:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 153, train_loss = -0.159


Epoch 154:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 154, train_loss = -0.190


Epoch 155:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 155, train_loss = -0.153


Epoch 156:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 156, train_loss = -0.128


Epoch 157:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 157, train_loss = -0.130


Epoch 158:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 158, train_loss = -0.162


Epoch 159:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 159, train_loss = -0.147


Epoch 160:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 160, train_loss = -0.200


Eval:   0%|          | 0/250 [00:00<?, ?it/s]

test accuracy = 0.496



Eval:   0%|          | 0/250 [00:00<?, ?it/s]

val accuracy = 0.464



Epoch 161:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 161, train_loss = -0.191


Epoch 162:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 162, train_loss = -0.206


Epoch 163:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 163, train_loss = -0.192


Epoch 164:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 164, train_loss = -0.209


Epoch 165:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 165, train_loss = -0.201


Epoch 166:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 166, train_loss = -0.126


Epoch 167:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 167, train_loss = -0.153


Epoch 168:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 168, train_loss = -0.194


Epoch 169:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 169, train_loss = -0.209


Epoch 170:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 170, train_loss = -0.206


Eval:   0%|          | 0/250 [00:00<?, ?it/s]

test accuracy = 0.500



Eval:   0%|          | 0/250 [00:00<?, ?it/s]

val accuracy = 0.476



Epoch 171:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 171, train_loss = -0.195


Epoch 172:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 172, train_loss = -0.181


Epoch 173:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 173, train_loss = -0.205


Epoch 174:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 174, train_loss = -0.175


Epoch 175:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 175, train_loss = -0.156


Epoch 176:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 176, train_loss = -0.172


Epoch 177:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 177, train_loss = -0.179


Epoch 178:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 178, train_loss = -0.178


Epoch 179:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 179, train_loss = -0.163


Epoch 180:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 180, train_loss = -0.143


Eval:   0%|          | 0/250 [00:00<?, ?it/s]

test accuracy = 0.492



Eval:   0%|          | 0/250 [00:00<?, ?it/s]

val accuracy = 0.460



Epoch 181:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 181, train_loss = -0.132


Epoch 182:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 182, train_loss = -0.147


Epoch 183:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 183, train_loss = -0.143


Epoch 184:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 184, train_loss = -0.165


Epoch 185:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 185, train_loss = -0.163


Epoch 186:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 186, train_loss = -0.153


Epoch 187:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 187, train_loss = -0.175


Epoch 188:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 188, train_loss = -0.225


Epoch 189:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 189, train_loss = -0.228


Epoch 190:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 190, train_loss = -0.225


Eval:   0%|          | 0/250 [00:00<?, ?it/s]

test accuracy = 0.512



Eval:   0%|          | 0/250 [00:00<?, ?it/s]

val accuracy = 0.492



Epoch 191:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 191, train_loss = -0.219


Epoch 192:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 192, train_loss = -0.176


Epoch 193:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 193, train_loss = -0.214


Epoch 194:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 194, train_loss = -0.216


Epoch 195:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 195, train_loss = -0.194


Epoch 196:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 196, train_loss = -0.204


Epoch 197:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 197, train_loss = -0.202


Epoch 198:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 198, train_loss = -0.198


Epoch 199:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 199, train_loss = -0.212


Epoch 200:   0%|          | 0/24 [00:00<?, ?it/s]

epochs = 200, train_loss = -0.148


Eval:   0%|          | 0/250 [00:00<?, ?it/s]

test accuracy = 0.500



Eval:   0%|          | 0/250 [00:00<?, ?it/s]

val accuracy = 0.472

Best Acc: 0.484

