In [1]:
#!pip install datasets
#!pip install transformers


from datasets import load_dataset
from torch.utils.data import DataLoader
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
import copy
from tqdm import tqdm_notebook
import os
from torch.optim.lr_scheduler import ReduceLROnPlateau
import warnings

warnings.filterwarnings("ignore")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [3]:
# сюда нужно поместить ваш токен который дает доступ к гиту, его можно сделать в настройках гита
!git clone https://<token>@github.com/EkaterinaAdishcheva/LinAlg_Project_AI_Master.git

Cloning into 'LinAlg_Project_AI_Master'...
remote: Enumerating objects: 2793, done.[K
remote: Counting objects: 100% (2793/2793), done.[K
remote: Compressing objects: 100% (1811/1811), done.[K
remote: Total 2793 (delta 1145), reused 2608 (delta 980), pack-reused 0[K
Receiving objects: 100% (2793/2793), 5.47 MiB | 5.55 MiB/s, done.
Resolving deltas: 100% (1145/1145), done.


In [2]:
from LinAlg_Project_AI_Master.src.transformers_my.models.bert.modeling_bert import (
    BertForMaskedLM,
)
from LinAlg_Project_AI_Master.src.transformers_my.models.bert.tokenization_bert import (
    BertTokenizer,
)

In [3]:
BertForMaskedLM

LinAlg_Project_AI_Master.src.transformers_my.models.bert.modeling_bert.BertForMaskedLM

In [4]:
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForMaskedLM.from_pretrained("bert-base-uncased", return_dict=True)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [5]:
cola = load_dataset("glue", "cola", split="train")



In [6]:
class MaskedCola(object):
    def __init__(self, dataset, tokenizer):
        self.df = dataset
        self.tokenizer = tokenizer

    def __getitem__(self, idx):
        # text = self.df[idx]['sentence'].lower()
        text = self.df["sentence"][idx].lower()

        inputs = self.tokenizer.encode_plus(
            text,  # Sentence to encode.
            add_special_tokens=True,  # Add '[CLS]' and '[SEP]'
            max_length=64,  # Pad & truncate all sentences.
            pad_to_max_length=True,
            return_attention_mask=True,  # Construct attn. masks.
            return_tensors="pt",  # Return pytorch tensors.
            truncation=True,  # Truncate everything more than max len
        )
        input_ids = inputs["input_ids"]

        labels = copy.deepcopy(input_ids)
        masked_idx = np.random.randint(sum(input_ids[0] != 0))

        input_ids[0][masked_idx] = tokenizer.mask_token_id

        labels[input_ids != tokenizer.mask_token_id] = -100

        # input_ids = input_ids, attention_mask = inputs['attention_mask'] , token_type_ids=inputs['token_type_ids'] , labels=labels

        return {
            "input_ids": input_ids,
            "attention_mask": inputs["attention_mask"],
            "token_type_ids": inputs["token_type_ids"],
            "labels": labels,
        }

    def __len__(self):
        return len(self.df["sentence"])


def collate_fn(batch):
    iid_ls = []
    atm_ls = []
    tti_ls = []
    l_ls = []

    for elem in batch:
        iid = elem["input_ids"]
        atm = elem["attention_mask"]
        tti = elem["token_type_ids"]
        l = elem["labels"]

        iid_ls.append(iid)
        atm_ls.append(atm)
        tti_ls.append(tti)
        l_ls.append(l)

    return torch.cat(iid_ls), torch.cat(atm_ls), torch.cat(tti_ls), torch.cat(l_ls)

In [7]:
dataset_train = MaskedCola(cola[:7000], tokenizer)
dataset_val = MaskedCola(cola[7000:], tokenizer)

dataloader_train = DataLoader(
    dataset_train,
    batch_size=128,
    shuffle=False,
    sampler=None,
    collate_fn=collate_fn,
    drop_last=False,
)
dataloader_val = DataLoader(
    dataset_val,
    batch_size=128,
    shuffle=False,
    sampler=None,
    collate_fn=collate_fn,
    drop_last=False,
)

In [16]:
# Функции для сохранения и загрузки чекпоинтов
def save_checkpoint(model, filename):
    with open(filename, "wb") as fp:
        torch.save(model.state_dict(), fp)


def load_checkpoint(model, filename):
    with open(filename, "rb") as fp:
        state_dict = torch.load(fp)

    model.load_state_dict(state_dict)


# цикл для трейна
def run_epoch_train(
    model, dataloader, loss_fn, optimizer, epoch, device, scheduler=None, log_acc=False
):
    stage = "train"

    model.train()
    torch.set_grad_enabled(True)

    model = model.to(device)

    if log_acc:
        total_true = 0
        total_labels = 0

    losses = []
    for batch in tqdm_notebook(
        dataloader,
        total=len(dataloader),
        desc=f"epoch: {str(epoch).zfill(3)} | {stage:5}",
    ):
        inp_ids = batch[0].to(device)
        att_masks = batch[1].to(device)
        token_types = batch[2].to(device)
        labels = batch[3].to(device)

        out = model(
            input_ids=inp_ids,
            attention_mask=att_masks,
            token_type_ids=token_types,
            labels=labels,
        )

        loss = out["loss"]
        logits = out["logits"]

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        losses.append(loss.detach().cpu().item())

        if log_acc:
            mask_token_index = (inp_ids == tokenizer.mask_token_id).nonzero(
                as_tuple=True
            )[1]

            pred_tokens = logits[:, mask_token_index, :].argmax(axis=-1).diag()

            true_preds = labels[labels != -100] == pred_tokens

            total_true += true_preds.sum().item()
            total_labels += len(true_preds)

    if scheduler is not None:
        scheduler.step(np.mean(losses))

    if log_acc:
        acc = total_true / total_labels
        print(acc)

    return np.mean(losses)


# цикл для валидации
def run_epoch_val(
    model, dataloader, loss_fn, optimizer, epoch, device, scheduler=None, log_acc=False
):
    stage = "val"
    torch.set_grad_enabled(False)
    model.eval()

    if log_acc:
        total_true = 0
        total_labels = 0

    model = model.to(device)

    losses = []
    for batch in tqdm_notebook(
        dataloader,
        total=len(dataloader),
        desc=f"epoch: {str(epoch).zfill(3)} | {stage:5}",
    ):
        inp_ids = batch[0].to(device)
        att_masks = batch[1].to(device)
        token_types = batch[2].to(device)
        labels = batch[3].to(device)

        out = model(
            input_ids=inp_ids,
            attention_mask=att_masks,
            token_type_ids=token_types,
            labels=labels,
        )

        loss = out["loss"]
        logits = out["logits"]

        losses.append(loss.detach().cpu().item())

        if log_acc:
            mask_token_index = (inp_ids == tokenizer.mask_token_id).nonzero(
                as_tuple=True
            )[1]

            pred_tokens = logits[:, mask_token_index, :].argmax(axis=-1).diag()

            true_preds = labels[labels != -100] == pred_tokens

            total_true += true_preds.sum().item()
            total_labels += len(true_preds)

    if log_acc:
        acc = total_true / total_labels
        print(acc)

    return np.mean(losses)


# цикл для прохождения по эпохам, обучения, валидации и логирования
def run_experiment(
    model,
    dataloader_train,
    dataloader_val,
    loss_fn,
    optimizer,
    num_epochs,
    device,
    output_dir,
    scheduler,
    log_acc=True,
):
    train_losses = []
    val_losses = []

    best_val_loss = np.inf
    best_val_loss_epoch = -1
    best_val_loss_fn = None

    os.makedirs(output_dir, exist_ok=True)

    for epoch in range(num_epochs):
        train_loss = run_epoch_train(
            model,
            dataloader_train,
            loss_fn,
            optimizer,
            epoch,
            device,
            scheduler,
            log_acc=log_acc,
        )
        train_losses.append(train_loss)

        val_loss = run_epoch_val(
            model, dataloader_val, loss_fn, optimizer, epoch, device, log_acc=log_acc
        )
        val_losses.append(val_loss)

        print(
            f"epoch: {str(epoch).zfill(3)} | train_loss: {train_loss:5.3f}, val_loss: {val_loss:5.3f} (best: {best_val_loss:5.3f})"
        )

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_val_loss_epoch = epoch

            output_fn = os.path.join(
                output_dir,
                f"epoch={str(epoch).zfill(2)}_valloss={best_val_loss:.3f}.pth.tar",
            )
            save_checkpoint(model, output_fn)
            print(f"New checkpoint saved to {output_fn}")

            # best_val_loss_fn = output_fn

        print()

    print(f"Best val_loss = {best_val_loss:.3f} reached at epoch {best_val_loss_epoch}")
    # load_checkpoint(model, best_val_loss_fn)

    return train_losses, val_losses, best_val_loss, model


def validate(model, dataloader, device):
    model.eval()
    model = model.to(device)

    total_true = 0
    total_labels = 0

    for batch in tqdm_notebook(dataloader):
        inp_ids = batch[0].to(device)
        att_masks = batch[1].to(device)
        token_types = batch[2].to(device)
        labels = batch[3].to(device)

        out = model(
            input_ids=inp_ids,
            attention_mask=att_masks,
            token_type_ids=token_types,
            labels=labels,
        )
        logits = out["logits"]

        mask_token_index = (inp_ids == tokenizer.mask_token_id).nonzero(as_tuple=True)[
            1
        ]

        pred_tokens = logits[:, mask_token_index, :].argmax(axis=-1).diag()

        true_preds = labels[labels != -100] == pred_tokens

        total_true += true_preds.sum().item()
        total_labels += len(true_preds)

    return total_true / total_labels


def truncation_finetuning(
    model,
    dataloader_train,
    dataloader_val,
    lr,
    num_epochs,
    device,
    output_dir,
    compression_scheme,
):
    encoder_layer_size = 768

    r_ls = [
        int(encoder_layer_size * 1 / compress_ratio)
        for compress_ratio in compression_scheme
    ]

    acc_ls = []

    model.to(device)

    for compression_stage in range(len(compression_scheme)):
        # compressing
        r = r_ls[compression_stage]
        model.set_truncation(r)
        print(r)

        # freeze everything except embedding params
        """
        for param in model.parameters():
          #print(param)
          param.requires_grad = False

        for param in model.bert.embeddings.parameters():
          param.requires_grad = True
        """

        # defining params
        loss_fn = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        scheduler = ReduceLROnPlateau(
            optimizer, mode="min", factor=0.1, patience=4, verbose=True, min_lr=3e-6
        )

        # fintuning
        (
            train_losses_baseline,
            val_losses_baseline,
            best_val_loss_baseline,
            cnn_baseline,
        ) = run_experiment(
            model,
            dataloader_train,
            dataloader_val,
            loss_fn,
            optimizer,
            num_epochs,
            device,
            output_dir,
            scheduler,
        )

        # validation + accuracy measure

        acc = validate(model, dataloader_val, device)

        acc_ls.append(acc)

        # decompressing
        W = (
            model.bert.embeddings.state_dict()["U"]
            * model.bert.embeddings.state_dict()["S"]
        ) @ model.bert.embeddings.state_dict()["Vt"]
        model.bert.embeddings.word_embeddings.weight.data = W

    return acc_ls

In [20]:
lr = 3e-5
num_epochs = 15

compression_scheme = [2, 4, 8, 10, 12]
output_dir = "Iter_compress_no_freez_2_4_8_10_12"

In [None]:
acc_ls = truncation_finetuning(
    model,
    dataloader_train,
    dataloader_val,
    lr,
    num_epochs,
    device,
    output_dir,
    compression_scheme,
)

Embeddings Truncated
384


epoch: 000 | train:   0%|          | 0/55 [00:00<?, ?it/s]

0.679


epoch: 000 | val  :   0%|          | 0/13 [00:00<?, ?it/s]

0.4261766602192134
epoch: 000 | train_loss: 1.643, val_loss: 4.829 (best:   inf)
New checkpoint saved to Iter_compress_no_freez_2_4_8_10_12/epoch=00_valloss=4.829.pth.tar



epoch: 001 | train:   0%|          | 0/55 [00:00<?, ?it/s]

0.6785714285714286


epoch: 001 | val  :   0%|          | 0/13 [00:00<?, ?it/s]

0.4687298517085751
epoch: 001 | train_loss: 1.589, val_loss: 4.568 (best: 4.829)
New checkpoint saved to Iter_compress_no_freez_2_4_8_10_12/epoch=01_valloss=4.568.pth.tar



epoch: 002 | train:   0%|          | 0/55 [00:00<?, ?it/s]

0.7022857142857143


epoch: 002 | val  :   0%|          | 0/13 [00:00<?, ?it/s]

0.4584139264990329
epoch: 002 | train_loss: 1.571, val_loss: 4.752 (best: 4.568)



epoch: 003 | train:   0%|          | 0/55 [00:00<?, ?it/s]

0.6838571428571428


epoch: 003 | val  :   0%|          | 0/13 [00:00<?, ?it/s]

0.4526112185686654
epoch: 003 | train_loss: 1.575, val_loss: 4.885 (best: 4.568)



epoch: 004 | train:   0%|          | 0/55 [00:00<?, ?it/s]

0.6927142857142857


epoch: 004 | val  :   0%|          | 0/13 [00:00<?, ?it/s]

0.43584784010315925
epoch: 004 | train_loss: 1.532, val_loss: 4.646 (best: 4.568)



epoch: 005 | train:   0%|          | 0/55 [00:00<?, ?it/s]

0.7035714285714286


epoch: 005 | val  :   0%|          | 0/13 [00:00<?, ?it/s]

0.4526112185686654
epoch: 005 | train_loss: 1.473, val_loss: 4.760 (best: 4.568)



epoch: 006 | train:   0%|          | 0/55 [00:00<?, ?it/s]

0.698


epoch: 006 | val  :   0%|          | 0/13 [00:00<?, ?it/s]

0.47259832366215343
epoch: 006 | train_loss: 1.533, val_loss: 4.507 (best: 4.568)
New checkpoint saved to Iter_compress_no_freez_2_4_8_10_12/epoch=06_valloss=4.507.pth.tar



epoch: 007 | train:   0%|          | 0/55 [00:00<?, ?it/s]

0.701


epoch: 007 | val  :   0%|          | 0/13 [00:00<?, ?it/s]

0.486782720825274
epoch: 007 | train_loss: 1.480, val_loss: 4.674 (best: 4.507)



epoch: 008 | train:   0%|          | 0/55 [00:00<?, ?it/s]

0.7064285714285714


epoch: 008 | val  :   0%|          | 0/13 [00:00<?, ?it/s]

0.463571889103804
epoch: 008 | train_loss: 1.456, val_loss: 4.586 (best: 4.507)



epoch: 009 | train:   0%|          | 0/55 [00:00<?, ?it/s]

0.7142857142857143


epoch: 009 | val  :   0%|          | 0/13 [00:00<?, ?it/s]

0.45454545454545453
epoch: 009 | train_loss: 1.431, val_loss: 4.662 (best: 4.507)



epoch: 010 | train:   0%|          | 0/55 [00:00<?, ?it/s]

0.714


epoch: 010 | val  :   0%|          | 0/13 [00:00<?, ?it/s]

0.46421663442940037
epoch: 010 | train_loss: 1.448, val_loss: 4.615 (best: 4.507)



epoch: 011 | train:   0%|          | 0/55 [00:00<?, ?it/s]

0.7117142857142857


epoch: 011 | val  :   0%|          | 0/13 [00:00<?, ?it/s]

0.46034816247582205
epoch: 011 | train_loss: 1.413, val_loss: 4.691 (best: 4.507)



epoch: 012 | train:   0%|          | 0/55 [00:00<?, ?it/s]

0.7178571428571429


epoch: 012 | val  :   0%|          | 0/13 [00:00<?, ?it/s]

0.46808510638297873
epoch: 012 | train_loss: 1.388, val_loss: 4.724 (best: 4.507)



epoch: 013 | train:   0%|          | 0/55 [00:00<?, ?it/s]

0.7197142857142858


epoch: 013 | val  :   0%|          | 0/13 [00:00<?, ?it/s]

0.42940038684719534
epoch: 013 | train_loss: 1.389, val_loss: 4.622 (best: 4.507)



epoch: 014 | train:   0%|          | 0/55 [00:00<?, ?it/s]

0.7287142857142858


epoch: 014 | val  :   0%|          | 0/13 [00:00<?, ?it/s]

0.47453255963894264
epoch: 014 | train_loss: 1.360, val_loss: 4.442 (best: 4.507)
New checkpoint saved to Iter_compress_no_freez_2_4_8_10_12/epoch=14_valloss=4.442.pth.tar

Best val_loss = 4.442 reached at epoch 14


  0%|          | 0/13 [00:00<?, ?it/s]

Embeddings Truncated
192


epoch: 000 | train:   0%|          | 0/55 [00:00<?, ?it/s]

0.6607142857142857


epoch: 000 | val  :   0%|          | 0/13 [00:00<?, ?it/s]

0.4345583494519665
epoch: 000 | train_loss: 1.703, val_loss: 4.691 (best:   inf)
New checkpoint saved to Iter_compress_no_freez_2_4_8_10_12/epoch=00_valloss=4.691.pth.tar



epoch: 001 | train:   0%|          | 0/55 [00:00<?, ?it/s]

0.6711428571428572


epoch: 001 | val  :   0%|          | 0/13 [00:00<?, ?it/s]

0.4261766602192134
epoch: 001 | train_loss: 1.654, val_loss: 5.086 (best: 4.691)



epoch: 002 | train:   0%|          | 0/55 [00:00<?, ?it/s]

0.6802857142857143


epoch: 002 | val  :   0%|          | 0/13 [00:00<?, ?it/s]

0.4629271437782076
epoch: 002 | train_loss: 1.611, val_loss: 4.675 (best: 4.691)
New checkpoint saved to Iter_compress_no_freez_2_4_8_10_12/epoch=02_valloss=4.675.pth.tar



epoch: 003 | train:   0%|          | 0/55 [00:00<?, ?it/s]

0.683


epoch: 003 | val  :   0%|          | 0/13 [00:00<?, ?it/s]

0.44680851063829785
epoch: 003 | train_loss: 1.598, val_loss: 4.911 (best: 4.675)



epoch: 004 | train:   0%|          | 0/55 [00:00<?, ?it/s]

0.6987142857142857


epoch: 004 | val  :   0%|          | 0/13 [00:00<?, ?it/s]

0.4384268214055448
epoch: 004 | train_loss: 1.513, val_loss: 4.768 (best: 4.675)



epoch: 005 | train:   0%|          | 0/55 [00:00<?, ?it/s]

0.6981428571428572


epoch: 005 | val  :   0%|          | 0/13 [00:00<?, ?it/s]

0.4480980012894907
epoch: 005 | train_loss: 1.538, val_loss: 4.521 (best: 4.675)
New checkpoint saved to Iter_compress_no_freez_2_4_8_10_12/epoch=05_valloss=4.521.pth.tar



epoch: 006 | train:   0%|          | 0/55 [00:00<?, ?it/s]

0.696


epoch: 006 | val  :   0%|          | 0/13 [00:00<?, ?it/s]

0.44745325596389424
epoch: 006 | train_loss: 1.491, val_loss: 4.841 (best: 4.521)



epoch: 007 | train:   0%|          | 0/55 [00:00<?, ?it/s]

0.7054285714285714


epoch: 007 | val  :   0%|          | 0/13 [00:00<?, ?it/s]

0.4558349451966473
epoch: 007 | train_loss: 1.468, val_loss: 4.527 (best: 4.521)



epoch: 008 | train:   0%|          | 0/55 [00:00<?, ?it/s]

0.7052857142857143


epoch: 008 | val  :   0%|          | 0/13 [00:00<?, ?it/s]

0.46099290780141844
epoch: 008 | train_loss: 1.450, val_loss: 4.373 (best: 4.521)
New checkpoint saved to Iter_compress_no_freez_2_4_8_10_12/epoch=08_valloss=4.373.pth.tar



epoch: 009 | train:   0%|          | 0/55 [00:00<?, ?it/s]

0.7138571428571429


epoch: 009 | val  :   0%|          | 0/13 [00:00<?, ?it/s]

0.43197936814958093
epoch: 009 | train_loss: 1.441, val_loss: 4.668 (best: 4.373)



epoch: 010 | train:   0%|          | 0/55 [00:00<?, ?it/s]

0.7221428571428572


epoch: 010 | val  :   0%|          | 0/13 [00:00<?, ?it/s]

0.4397163120567376
epoch: 010 | train_loss: 1.375, val_loss: 4.685 (best: 4.373)



epoch: 011 | train:   0%|          | 0/55 [00:00<?, ?it/s]

0.7174285714285714


epoch: 011 | val  :   0%|          | 0/13 [00:00<?, ?it/s]

0.45003223726627983
epoch: 011 | train_loss: 1.401, val_loss: 4.724 (best: 4.373)



epoch: 012 | train:   0%|          | 0/55 [00:00<?, ?it/s]

0.7225714285714285


epoch: 012 | val  :   0%|          | 0/13 [00:00<?, ?it/s]

0.47840103159252095
epoch: 012 | train_loss: 1.339, val_loss: 4.462 (best: 4.373)



epoch: 013 | train:   0%|          | 0/55 [00:00<?, ?it/s]

0.728


epoch: 013 | val  :   0%|          | 0/13 [00:00<?, ?it/s]

0.4661508704061896
epoch: 013 | train_loss: 1.331, val_loss: 4.727 (best: 4.373)



epoch: 014 | train:   0%|          | 0/55 [00:00<?, ?it/s]

0.7274285714285714


epoch: 014 | val  :   0%|          | 0/13 [00:00<?, ?it/s]

0.4448742746615087
epoch: 014 | train_loss: 1.340, val_loss: 4.924 (best: 4.373)

Best val_loss = 4.373 reached at epoch 8


  0%|          | 0/13 [00:00<?, ?it/s]

Embeddings Truncated
96


epoch: 000 | train:   0%|          | 0/55 [00:00<?, ?it/s]

0.6118571428571429


epoch: 000 | val  :   0%|          | 0/13 [00:00<?, ?it/s]

0.42359767891682787
epoch: 000 | train_loss: 1.916, val_loss: 5.013 (best:   inf)
New checkpoint saved to Iter_compress_no_freez_2_4_8_10_12/epoch=00_valloss=5.013.pth.tar



epoch: 001 | train:   0%|          | 0/55 [00:00<?, ?it/s]

0.639


epoch: 001 | val  :   0%|          | 0/13 [00:00<?, ?it/s]

0.4274661508704062
epoch: 001 | train_loss: 1.793, val_loss: 4.839 (best: 5.013)
New checkpoint saved to Iter_compress_no_freez_2_4_8_10_12/epoch=01_valloss=4.839.pth.tar



epoch: 002 | train:   0%|          | 0/55 [00:00<?, ?it/s]

0.6542857142857142


epoch: 002 | val  :   0%|          | 0/13 [00:00<?, ?it/s]

0.43649258542875563
epoch: 002 | train_loss: 1.698, val_loss: 4.908 (best: 4.839)



epoch: 003 | train:   0%|          | 0/55 [00:00<?, ?it/s]

0.6611428571428571


epoch: 003 | val  :   0%|          | 0/13 [00:00<?, ?it/s]

0.4455190199871051
epoch: 003 | train_loss: 1.698, val_loss: 4.768 (best: 4.839)
New checkpoint saved to Iter_compress_no_freez_2_4_8_10_12/epoch=03_valloss=4.768.pth.tar



epoch: 004 | train:   0%|          | 0/55 [00:00<?, ?it/s]

0.6707142857142857


epoch: 004 | val  :   0%|          | 0/13 [00:00<?, ?it/s]

0.4345583494519665
epoch: 004 | train_loss: 1.663, val_loss: 4.780 (best: 4.768)



epoch: 005 | train:   0%|          | 0/55 [00:00<?, ?it/s]

0.6784285714285714


epoch: 005 | val  :   0%|          | 0/13 [00:00<?, ?it/s]

0.4448742746615087
epoch: 005 | train_loss: 1.590, val_loss: 4.544 (best: 4.768)
New checkpoint saved to Iter_compress_no_freez_2_4_8_10_12/epoch=05_valloss=4.544.pth.tar



epoch: 006 | train:   0%|          | 0/55 [00:00<?, ?it/s]

0.6888571428571428


epoch: 006 | val  :   0%|          | 0/13 [00:00<?, ?it/s]

0.42875564152159895
epoch: 006 | train_loss: 1.537, val_loss: 4.762 (best: 4.544)



epoch: 007 | train:   0%|          | 0/55 [00:00<?, ?it/s]

0.6864285714285714


epoch: 007 | val  :   0%|          | 0/13 [00:00<?, ?it/s]

0.44229529335912315
epoch: 007 | train_loss: 1.533, val_loss: 4.822 (best: 4.544)



epoch: 008 | train:   0%|          | 0/55 [00:00<?, ?it/s]

0.6945714285714286


epoch: 008 | val  :   0%|          | 0/13 [00:00<?, ?it/s]

0.44294003868471954
epoch: 008 | train_loss: 1.511, val_loss: 4.794 (best: 4.544)



epoch: 009 | train:   0%|          | 0/55 [00:00<?, ?it/s]

0.6891428571428572


epoch: 009 | val  :   0%|          | 0/13 [00:00<?, ?it/s]