In [1]:
pip install sacrebleu

Collecting sacrebleu
  Downloading sacrebleu-2.4.0-py3-none-any.whl.metadata (57 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m57.4/57.4 kB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting portalocker (from sacrebleu)
  Downloading portalocker-2.8.2-py3-none-any.whl.metadata (8.5 kB)
Downloading sacrebleu-2.4.0-py3-none-any.whl (106 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m106.3/106.3 kB[0m [31m4.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading portalocker-2.8.2-py3-none-any.whl (17 kB)
Installing collected packages: portalocker, sacrebleu
Successfully installed portalocker-2.8.2 sacrebleu-2.4.0
Note: you may need to restart the kernel to use updated packages.


In [15]:
import logging
import os
import pathlib
import random
import re
import sys
sys.path.append('/kaggle/input/sentence-compression')
import click
import sacrebleu

import torch

import torch.nn as nn
import tqdm

import config
import corpora

import utils
from main_fix import Network as FixNN
from main_sent import Network as ComNN
from sklearn.metrics import classification_report, precision_recall_fscore_support

In [None]:
# Get the current working directory
cwd = os.getcwd()

print(cwd)


In [16]:
cwd = os.path.dirname('/kaggle/working/')
logger = logging.getLogger("main")

In [17]:
class Network(nn.Module):
    def __init__(
        self, word2index, embeddings, prior,
    ):
        super().__init__()
        self.logger = logging.getLogger(f"{__name__}")
        self.word2index = word2index
        self.index2word = {i: k for k, i in word2index.items()}
        self.fix_gen = FixNN(
            embedding_type="glove",
            vocab_size=len(word2index),
            embedding_dim=config.embedding_dim,
            
            embeddings=embeddings,
            dropout=config.fix_dropout,
            hidden_dim=config.fix_hidden_dim,
        )
        self.com_nn = ComNN(
            embeddings=embeddings, hidden_size=config.sem_hidden_dim, prior=prior, device=config.DEV
        )

    def forward(self, x, target, seq_lens):
        x1 = nn.utils.rnn.pad_sequence(x, batch_first=True)
        target = nn.utils.rnn.pad_sequence(target, batch_first=True, padding_value=-1)

        fixations = torch.sigmoid(self.fix_gen(x1, seq_lens))
        # fixations = None

        loss, pred, atts = self.com_nn(x1, target, fixations)
        return loss, pred, atts, fixations

In [18]:
def load_corpus(corpus_name, splits):
    if not splits:
        return

    logger.info("loading corpus")
    if corpus_name == "google":
        load_fn = corpora.load_google

    corpus = {}
    langs = []

    if "train" in splits:
        train_pairs, train_lang = load_fn("train", max_len=200)
        corpus["train"] = train_pairs
        langs.append(train_lang)
    if "val" in splits:
        val_pairs, val_lang = load_fn("val")
        corpus["val"] = val_pairs
        langs.append(val_lang)
    if "test" in splits:
        test_pairs, test_lang = load_fn("test")
        corpus["test"] = test_pairs
        langs.append(test_lang)

    logger.info("creating word index")
    lang = langs[0]
    for _lang in langs[1:]:
        lang += _lang
    word2index = lang.word2index

    index2word = {i: w for w, i in word2index.items()}

    return corpus, word2index, index2word

In [19]:
def init_network(word2index, prior):
    logger.info("loading embeddings")
    vocabulary = sorted(word2index.keys())
    embeddings = utils.load_glove(vocabulary)

    logger.info("initializing model")
    network = Network(word2index=word2index, embeddings=embeddings, prior=prior)
    network.to(config.DEV)

    print(f"#parameters: {sum(p.numel() for p in network.parameters())}")

    return network

In [20]:
@click.group(context_settings=dict(help_option_names=["-h", "--help"]))
@click.option("-v", "--verbose", count=True)
@click.option("-d", "--debug", is_flag=True)
def main(verbose, debug):
    if verbose == 0:
        loglevel = logging.ERROR
    elif verbose == 1:
        loglevel = logging.WARN
    elif verbose >= 2:
        loglevel = logging.INFO

    if debug:
        loglevel = logging.DEBUG

    logging.basicConfig(
        format="[%(asctime)s] <%(name)s> %(levelname)s: %(message)s",
        datefmt="%d.%m. %H:%M:%S",
        level=loglevel,
    )

    logger.debug("arguments: %s" % str(sys.argv))

In [21]:
def train(corpus_name, model_name, fixation_weights=None, freeze_fixations=False, debug=False, prior=0.5):
    corpus, word2index, index2word = load_corpus(corpus_name, ["train", "val"])
    train_pairs = corpus["train"]
    val_pairs = corpus["val"]
    network = init_network(word2index, prior)

    model_dir = os.path.join("models", model_name)
    logger.debug("creating model dir %s" % model_dir)
    pathlib.Path(model_dir).mkdir(parents=True)

    if fixation_weights is not None:
        logger.info("loading fixation prediction checkpoint")
        checkpoint = torch.load(fixation_weights, map_location=config.DEV)
        if "word2index" in checkpoint:
            weights = checkpoint["weights"]
        else:
            weights = checkpoint

        # remove the embedding layer before loading
        weights = {
            k: v for k, v in weights.items() if not k.startswith("pre.embedding_layer")
        }
        network.fix_gen.load_state_dict(weights, strict=False)

        if freeze_fixations:
            logger.info("freezing fixation generation network")
            for p in network.fix_gen.parameters():
                p.requires_grad = False

    optimizer = torch.optim.Adam(network.parameters(), lr=config.learning_rate)


    epoch = 3
    # Load previously saved model and word2index
    checkpoint_path = '/kaggle/input/sentence-compression/sent_3 (1).tar'
    checkpoint = torch.load(checkpoint_path, map_location=config.DEV)
    network.load_state_dict(checkpoint["weights"])

    word2index = checkpoint["word2index"]
    best_val_loss = 2.5685
    batch_size = 5

    while epoch<6:
        train_batch_iter = utils.sent_iter(
            sents=train_pairs, word2index=word2index, batch_size=batch_size
        )
        val_batch_iter = utils.sent_iter(
            sents=val_pairs, word2index=word2index, batch_size=batch_size
        )

        total_train_loss = 0
        total_val_loss = 0

        network.train()
        for i, batch in tqdm.tqdm(
            enumerate(train_batch_iter, 1), total=len(train_pairs) // batch_size + 1
        ):
            optimizer.zero_grad()

            raw_sent, sent, target = batch
            seq_lens = [len(x) for x in sent]
            loss, prediction, attention, fixations = network(sent, target, seq_lens)

            prediction = prediction.detach().cpu().numpy()

            torch.nn.utils.clip_grad_norm(network.parameters(), max_norm=5)

            loss.backward()
            optimizer.step()

            total_train_loss += loss.item()

        avg_train_loss = total_train_loss / len(train_pairs)

        val_sents = []
        val_preds = []
        val_targets = []

        network.eval()
        for i, batch in tqdm.tqdm(
            enumerate(val_batch_iter), total=len(val_pairs) // batch_size + 1
        ):
            raw_sent, sent, target = batch
            seq_lens = [len(x) for x in sent]
            loss, prediction, attention, fixations = network(sent, target, seq_lens)

            prediction = prediction.detach().cpu().numpy()

            for i, l in enumerate(seq_lens):
                val_sents.append(raw_sent[i][:l])
                val_preds.append(prediction[i][:l].tolist())
                val_targets.append(target[i][:l].tolist())

            total_val_loss += loss.item()

        avg_val_loss = total_val_loss / len(val_pairs)
        
        # Calculate compression ratio
        compression_ratios = [len(sent) / pred.count(1) if 1 in pred else 0 for sent, pred in zip(val_sents, val_preds)]


        print(
            f"epoch {epoch} train_loss {avg_train_loss:.4f} val_loss {avg_val_loss:.4f}"
        )

        print(
            classification_report(
                [x for y in val_targets for x in y],
                [x for y in val_preds for x in y],
                target_names=["not_del", "del"],
                digits=5,
            )
        )
        
        # Print the classification report and compression ratio
        print(f"Avg Compression Ratio: {sum(compression_ratios) / len(compression_ratios):.4f}")

        with open(f"models/{model_name}/val_original_{epoch}.txt", "w") as oh, open(
            f"models/{model_name}/val_pred_{epoch}.txt", "w"
        ) as ph, open(f"models/{model_name}/val_gold_{epoch}.txt", "w") as gh:
            for sent, preds, golds in zip(val_sents, val_preds, val_targets):
                pred_compressed = [
                    word for word, delete in zip(sent, preds) if not delete
                ]
                gold_compressed = [
                    word for word, delete in zip(sent, golds) if not delete
                ]

                oh.write(" ".join(sent))
                ph.write(" ".join(pred_compressed))
                gh.write(" ".join(gold_compressed))

                oh.write("\n")
                ph.write("\n")
                gh.write("\n")

        if best_val_loss is None or avg_val_loss < best_val_loss:
            delta = avg_val_loss - best_val_loss if best_val_loss is not None else 0.0
            best_val_loss = avg_val_loss
            print(
                f"new best model epoch {epoch} val loss {avg_val_loss:.4f} ({delta:.4f})"
            )

        utils.save_model(
            network, word2index, f"models/{model_name}/{model_name}_{epoch}"
        )

        epoch += 1


In [22]:
@main.command()
@click.option(
    "-c",
    "--corpus",
    
    "corpus_name",
    required=True,
    type=click.Choice(sorted(["google",])),
)
@click.option("-w", "--model_weights", required=True)
@click.option("-p", "--prior", type=float, default=.5)
@click.option("-l", "--longest", is_flag=True)
@click.option("-s", "--shortest", is_flag=True)
@click.option("-d", "--detailed", is_flag=True)
def test(corpus_name, model_weights, prior, longest, shortest, detailed):
    if longest and shortest:
        print("longest and shortest are mutually exclusive", file=sys.stderr)
        sys.exit()

    corpus, word2index, index2word = load_corpus(corpus_name, ["test"])
    test_pairs = corpus["test"]

    model_name = os.path.basename(os.path.dirname(model_weights))
    epoch = re.search("_(\d+).tar", model_weights).group(1)

    logger.info("loading model checkpoint")
    checkpoint = torch.load(model_weights, map_location=config.DEV)
    if "word2index" in checkpoint:
        weights = checkpoint["weights"]
        word2index = checkpoint["word2index"]
        index2word = {i: w for w, i in word2index.items()}
    else:
        asdf

    network = init_network(word2index, prior)
    network.eval()

    # remove the embedding layer before loading
    # weights = {k: v for k, v in weights.items() if not "embedding" in k}
    # actually load the parameters
    network.load_state_dict(weights, strict=False)

    total_test_loss = 0

    batch_size = 20

    test_batch_iter = utils.sent_iter(
        sents=test_pairs, word2index=word2index, batch_size=batch_size
    )

    test_sents = []
    test_preds = []
    test_targets = []

    for i, batch in tqdm.tqdm(
        enumerate(test_batch_iter, 1), total=len(test_pairs) // batch_size + 1
    ):
        raw_sent, sent, target = batch
        seq_lens = [len(x) for x in sent]
        loss, prediction, attention, fixations = network(sent, target, seq_lens)

        prediction = prediction.detach().cpu().numpy()

        for i, l in enumerate(
            seq_lens
        ):
            test_sents.append(raw_sent[i][:l])
            test_preds.append(prediction[i][:l].tolist())
            test_targets.append(target[i][:l].tolist())

        total_test_loss += loss.item()

    avg_test_loss = total_test_loss / len(test_pairs)

    print(f"test_loss {avg_test_loss:.4f}")

    if longest:
        avg_len = sum(len(s) for s in test_sents)/len(test_sents)
        test_sents = list(filter(lambda x: len(x) > avg_len, test_sents))
        test_preds = list(filter(lambda x: len(x) > avg_len, test_preds))
        test_targets = list(filter(lambda x: len(x) > avg_len, test_targets))
    elif shortest:
        avg_len = sum(len(s) for s in test_sents)/len(test_sents)
        test_sents = list(filter(lambda x: len(x) <= avg_len, test_sents))
        test_preds = list(filter(lambda x: len(x) <= avg_len, test_preds))
        test_targets = list(filter(lambda x: len(x) <= avg_len, test_targets))

    if detailed:
        for test_sent, test_target, test_pred in zip(test_sents, test_targets, test_preds):
            print(precision_recall_fscore_support(test_target, test_pred, average="weighted")[2], test_sent, test_target, test_pred)
    else:
        print(
            classification_report(
                [x for y in test_targets for x in y],
                [x for y in test_preds for x in y],
                target_names=["not_del", "del"],
                digits=5,
            )
        )

    with open(f"models/{model_name}/test_original_{epoch}.txt", "w") as oh, open(
        f"models/{model_name}/test_pred_{epoch}.txt", "w"
    ) as ph, open(f"models/{model_name}/test_gold_{epoch}.txt", "w") as gh:
        for sent, preds, golds in zip(test_sents, test_preds, test_targets):
            pred_compressed = [word for word, delete in zip(sent, preds) if not delete]
            gold_compressed = [word for word, delete in zip(sent, golds) if not delete]

            oh.write(" ".join(sent))
            ph.write(" ".join(pred_compressed))
            gh.write(" ".join(gold_compressed))

            oh.write("\n")
            ph.write("\n")
            gh.write("\n")

In [23]:
def predict(corpus_name, model_weights, prior=0.5, longest=True, shortest=False):
    if longest and shortest:
        print("longest and shortest are mutually exclusive", file=sys.stderr)
        sys.exit()

    corpus, word2index, index2word = load_corpus(corpus_name, ["test"])
    test_pairs = corpus["test"]

    model_name = os.path.basename(os.path.dirname(model_weights))
    epoch = 8

    logger.info("loading model checkpoint")
    checkpoint = torch.load(model_weights, map_location=config.DEV)
    if "word2index" in checkpoint:
        weights = checkpoint["weights"]
        word2index = checkpoint["word2index"]
        index2word = {i: w for w, i in word2index.items()}
    else:
        asdf

    network = init_network(word2index, prior)
    network.eval()

    # remove the embedding layer before loading
    # weights = {k: v for k, v in weights.items() if not "embedding" in k}
    # actually load the parameters
    network.load_state_dict(weights, strict=False)

    total_test_loss = 0

    batch_size = 20

    test_batch_iter = utils.sent_iter(
        sents=test_pairs, word2index=word2index, batch_size=batch_size
    )

    test_sents = []
    test_preds = []
    test_attentions = []
    test_fixations = []

    for i, batch in tqdm.tqdm(
        enumerate(test_batch_iter, 1), total=len(test_pairs) // batch_size + 1
    ):
        raw_sent, sent, target = batch
        seq_lens = [len(x) for x in sent]
        loss, prediction, attention, fixations = network(sent, target, seq_lens)

        prediction = prediction.detach().cpu().numpy()
        attention = attention.detach().cpu().numpy()
        if fixations is not None:
            fixations = fixations.detach().cpu().numpy()

        for i, l in enumerate(
            seq_lens
        ):
            test_sents.append(raw_sent[i][:l])
            test_preds.append(prediction[i][:l].tolist())
            test_attentions.append(attention[i][:l].tolist())
            if fixations is not None:
                test_fixations.append(fixations[i][:l].tolist())
            else:
                test_fixations.append([])

        total_test_loss += loss.item()

    avg_test_loss = total_test_loss / len(test_pairs)

    if longest:
        avg_len = sum(len(s) for s in test_sents)/len(test_sents)
        test_sents = list(filter(lambda x: len(x) > avg_len, test_sents))
        test_preds = list(filter(lambda x: len(x) > avg_len, test_preds))
        test_attentions = list(filter(lambda x: len(x) > avg_len, test_attentions))
        test_fixations = list(filter(lambda x: len(x) > avg_len, test_fixations))
    elif shortest:
        avg_len = sum(len(s) for s in test_sents)/len(test_sents)
        test_sents = list(filter(lambda x: len(x) <= avg_len, test_sents))
        test_preds = list(filter(lambda x: len(x) <= avg_len, test_preds))
        test_attentions = list(filter(lambda x: len(x) <= avg_len, test_attentions))
        test_fixations = list(filter(lambda x: len(x) <= avg_len, test_fixations))

    print(f"sentence\tprediction\tattentions\tfixations")
    for s, p, a, f in zip(test_sents, test_preds, test_attentions, test_fixations):
        a = [x[:len(a)] for x in a]
        print(f"{s}\t{p}\t{a}\t{f}")

In [13]:
train('google','sent')

128 4
#parameters: 147002715


  torch.nn.utils.clip_grad_norm(network.parameters(), max_norm=5)
100%|██████████| 8897/8897 [50:58<00:00,  2.91it/s]  
100%|█████████▉| 1000/1001 [02:08<00:00,  7.76it/s]


epoch 1 train_loss 1.7352 val_loss 1.5888
              precision    recall  f1-score   support

     not_del    0.82951   0.89309   0.86013    173095
         del    0.76864   0.65928   0.70977     93250

    accuracy                        0.81123    266345
   macro avg    0.79908   0.77619   0.78495    266345
weighted avg    0.80820   0.81123   0.80749    266345

new best model epoch 1 val loss 1.5888 (0.0000)


  torch.nn.utils.clip_grad_norm(network.parameters(), max_norm=5)
100%|██████████| 8897/8897 [50:57<00:00,  2.91it/s]  
100%|█████████▉| 1000/1001 [02:07<00:00,  7.87it/s]


epoch 2 train_loss 1.5240 val_loss 1.4919
              precision    recall  f1-score   support

     not_del    0.84243   0.90543   0.87280    173095
         del    0.79616   0.68564   0.73678     93250

    accuracy                        0.82848    266345
   macro avg    0.81930   0.79554   0.80479    266345
weighted avg    0.82623   0.82848   0.82518    266345

new best model epoch 2 val loss 1.4919 (-0.0969)


  torch.nn.utils.clip_grad_norm(network.parameters(), max_norm=5)
100%|██████████| 8897/8897 [50:51<00:00,  2.92it/s]  
100%|█████████▉| 1000/1001 [02:06<00:00,  7.92it/s]


epoch 3 train_loss 1.4095 val_loss 1.4265
              precision    recall  f1-score   support

     not_del    0.85407   0.90413   0.87839    173095
         del    0.80032   0.71324   0.75428     93250

    accuracy                        0.83730    266345
   macro avg    0.82720   0.80869   0.81633    266345
weighted avg    0.83525   0.83730   0.83494    266345

new best model epoch 3 val loss 1.4265 (-0.0654)


In [12]:
train('google','sent')

128 4
#parameters: 147002715


  torch.nn.utils.clip_grad_norm(network.parameters(), max_norm=5)
100%|█████████▉| 17793/17794 [1:02:48<00:00,  4.72it/s]
100%|█████████▉| 2000/2001 [02:35<00:00, 12.84it/s]


epoch 4 train_loss 2.4024 val_loss 2.4594
              precision    recall  f1-score   support

     not_del    0.86270   0.90822   0.88488    173095
         del    0.81114   0.73169   0.76937     93250

    accuracy                        0.84642    266345
   macro avg    0.83692   0.81996   0.82712    266345
weighted avg    0.84465   0.84642   0.84444    266345

Avg Compression Ratio: 1.0000


  torch.nn.utils.clip_grad_norm(network.parameters(), max_norm=5)
100%|█████████▉| 17793/17794 [1:02:41<00:00,  4.73it/s]
100%|█████████▉| 2000/2001 [02:33<00:00, 12.99it/s]


epoch 5 train_loss 2.2128 val_loss 2.4321
              precision    recall  f1-score   support

     not_del    0.86254   0.91486   0.88793    173095
         del    0.82190   0.72936   0.77287     93250

    accuracy                        0.84991    266345
   macro avg    0.84222   0.82211   0.83040    266345
weighted avg    0.84831   0.84991   0.84765    266345

Avg Compression Ratio: 1.0000


  torch.nn.utils.clip_grad_norm(network.parameters(), max_norm=5)
100%|█████████▉| 17793/17794 [1:02:49<00:00,  4.72it/s]
100%|█████████▉| 2000/2001 [02:41<00:00, 12.40it/s]


epoch 6 train_loss 2.0533 val_loss 2.4533
              precision    recall  f1-score   support

     not_del    0.86768   0.90992   0.88830    173095
         del    0.81617   0.74242   0.77755     93250

    accuracy                        0.85128    266345
   macro avg    0.84193   0.82617   0.83292    266345
weighted avg    0.84965   0.85128   0.84952    266345

Avg Compression Ratio: 1.0000


  torch.nn.utils.clip_grad_norm(network.parameters(), max_norm=5)
100%|█████████▉| 17793/17794 [1:03:41<00:00,  4.66it/s]
100%|█████████▉| 2000/2001 [02:42<00:00, 12.34it/s]


epoch 7 train_loss 1.8859 val_loss 2.5894
              precision    recall  f1-score   support

     not_del    0.86638   0.91056   0.88792    173095
         del    0.81662   0.73931   0.77605     93250

    accuracy                        0.85061    266345
   macro avg    0.84150   0.82494   0.83198    266345
weighted avg    0.84896   0.85061   0.84875    266345

Avg Compression Ratio: 1.0000


  torch.nn.utils.clip_grad_norm(network.parameters(), max_norm=5)
  3%|▎         | 496/17794 [01:46<1:01:38,  4.68it/s]


KeyboardInterrupt: 

In [12]:
train('google','sent')

128 4
#parameters: 147002715


  torch.nn.utils.clip_grad_norm(network.parameters(), max_norm=5)
100%|█████████▉| 17793/17794 [1:01:47<00:00,  4.80it/s]
100%|█████████▉| 2000/2001 [02:29<00:00, 13.37it/s]


epoch 8 train_loss 1.7026 val_loss 2.7080
              precision    recall  f1-score   support

     not_del    0.87317   0.90193   0.88732    173095
         del    0.80611   0.75683   0.78069     93250

    accuracy                        0.85113    266345
   macro avg    0.83964   0.82938   0.83401    266345
weighted avg    0.84969   0.85113   0.84999    266345

Avg Compression Ratio: 3.3871


In [None]:
train('google','sent')

128 4
#parameters: 147002715


  torch.nn.utils.clip_grad_norm(network.parameters(), max_norm=5)
100%|█████████▉| 17793/17794 [1:30:37<00:00,  3.27it/s]  
100%|█████████▉| 2000/2001 [03:29<00:00,  9.54it/s]


epoch 1 train_loss 3.0409 val_loss 2.7835
              precision    recall  f1-score   support

     not_del    0.84072   0.89385   0.86647    173095
         del    0.77677   0.68564   0.72837     93250

    accuracy                        0.82095    266345
   macro avg    0.80874   0.78975   0.79742    266345
weighted avg    0.81833   0.82095   0.81812    266345

Avg Compression Ratio: 3.5310
new best model epoch 1 val loss 2.7835 (0.0000)


  torch.nn.utils.clip_grad_norm(network.parameters(), max_norm=5)
100%|█████████▉| 17793/17794 [1:30:45<00:00,  3.27it/s]  
100%|█████████▉| 2000/2001 [03:30<00:00,  9.52it/s]


epoch 2 train_loss 2.6296 val_loss 2.5685
              precision    recall  f1-score   support

     not_del    0.85364   0.90658   0.87931    173095
         del    0.80402   0.71146   0.75492     93250

    accuracy                        0.83827    266345
   macro avg    0.82883   0.80902   0.81711    266345
weighted avg    0.83627   0.83827   0.83576    266345

Avg Compression Ratio: 3.5496
new best model epoch 2 val loss 2.5685 (-0.2150)


  torch.nn.utils.clip_grad_norm(network.parameters(), max_norm=5)
 96%|█████████▌| 17055/17794 [1:26:59<04:00,  3.07it/s]  

In [None]:
train('google','sent')

128 4
#parameters: 147002715


  torch.nn.utils.clip_grad_norm(network.parameters(), max_norm=5)
100%|█████████▉| 17793/17794 [1:32:12<00:00,  3.22it/s]  
100%|█████████▉| 2000/2001 [03:34<00:00,  9.33it/s]


epoch 3 train_loss 2.2247 val_loss 2.4341
              precision    recall  f1-score   support

     not_del    0.86696   0.90984   0.88788    173095
         del    0.81573   0.74084   0.77648     93250

    accuracy                        0.85067    266345
   macro avg    0.84134   0.82534   0.83218    266345
weighted avg    0.84902   0.85067   0.84888    266345

Avg Compression Ratio: 3.4641
new best model epoch 3 val loss 2.4341 (-0.1344)


  torch.nn.utils.clip_grad_norm(network.parameters(), max_norm=5)
 62%|██████▏   | 11019/17794 [57:22<31:05,  3.63it/s]  