In [1]:
# %load 1.1.train_model.py
import commonsetting
from models import perceptual_network, Encoder, Class_out, Conf_out
from dataloader import CustomImageDataset, concatenate_transform_steps
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from torch import nn
import torch
from tqdm import tqdm
import numpy as np



def determine_training_stops(net,
                             idx_epoch:int,
                             warmup_epochs:int,
                             valid_loss,
                             counts: int        = 0,
                             device             = commonsetting.device,
                             best_valid_loss    = np.inf,
                             tol:float          = 1e-4,
                             f_name:str         = 'temp.h5',
                             ):
    """
    Track validation improvement for early stopping and checkpoint the best model.

    During the first ``warmup_epochs`` epochs nothing happens: the incoming
    ``best_valid_loss`` and ``counts`` are returned unchanged and no file is
    written. After warmup, a ``valid_loss`` that is strictly lower than
    ``best_valid_loss`` by at least ``tol`` becomes the new best. The
    network's ``state_dict`` is saved to ``f_name`` and the counter is reset
    to 0. Any other value increments the counter.

    Parameters
    ----------
    net : nn.Module
        Model whose ``state_dict()`` is saved when the loss improves.
    idx_epoch : int
        Zero-based index of the current epoch.
    warmup_epochs : int
        Number of initial epochs for which no comparison is made.
    valid_loss : float
        Validation loss of the current epoch. The caller in this file passes
        the per-batch *sum* returned by ``validation_loop``, so ``tol`` and
        ``best_valid_loss`` are on that scale.
    counts : int, optional
        Consecutive post-warmup epochs without improvement. The default is 0.
    device : optional
        Not used by this function; kept for backward compatibility with
        existing callers.
    best_valid_loss : float, optional
        Best validation loss seen so far. The default is np.inf.
    tol : float, optional
        Minimum decrease that counts as an improvement. The default is 1e-4.
    f_name : str, optional
        Destination for ``torch.save``. The file is a PyTorch checkpoint
        regardless of its extension (the ``.h5`` suffix does not make it
        HDF5). The default is 'temp.h5'.

    Returns
    -------
    best_valid_loss : float
        Updated best validation loss.
    counts : int
        Updated number of epochs without improvement; the caller compares it
        against its patience to decide when to stop training.
    """
    if idx_epoch < warmup_epochs:
        # Still warming up: leave the early-stopping state untouched.
        return best_valid_loss, counts

    # The improvement must be strict AND at least `tol`. With
    # best_valid_loss == inf, the first post-warmup epoch always qualifies.
    improved = valid_loss < best_valid_loss and abs(best_valid_loss - valid_loss) >= tol
    if improved:
        best_valid_loss = valid_loss
        # Save only parameters/buffers (the approach recommended by the
        # PyTorch serialization docs) rather than pickling the whole module.
        # Reload with net.load_state_dict(torch.load(f_name)).
        torch.save(net.state_dict(), f_name)
        counts = 0
    else:
        counts += 1
    return best_valid_loss, counts

def training_loop(dataloader_train, device, model, loss_function, optimizer):
    """
    Run one training epoch over ``dataloader_train``.

    For each batch, the class prediction is trained against the label. The
    confidence output is trained to predict whether the argmax of that class
    prediction matches the argmax of the label. Both terms use
    ``loss_function`` and are summed before back-propagation.

    Parameters
    ----------
    dataloader_train : iterable
        Yields ``(batch_image, batch_label)`` pairs. ``batch_label`` is
        combined with ``torch.vstack(...).T``, so it is presumably a sequence
        of per-class tensors of length ``batch_size`` (default collate of a
        per-sample label list) -- TODO confirm against CustomImageDataset.
    device : torch.device or str
        Device the batch tensors are moved to.
    model : nn.Module
        Must return ``(features, hidden_representation, prediction, confidence)``.
    loss_function : callable
        Applied to both the class output and the confidence output.
    optimizer : torch.optim.Optimizer
        Optimizer that is zeroed and stepped once per batch.

    Returns
    -------
    model : nn.Module
        The same model instance, updated in place.
    train_loss : float
        Sum of the per-batch combined losses. The progress bar shows the
        running mean instead.
    """
    model.train(True)
    progress = tqdm(dataloader_train)
    train_loss = 0.

    for idx_batch, (batch_image, batch_label) in enumerate(progress):
        batch_label = torch.vstack(batch_label).T.float().to(device)
        batch_image = batch_image.to(device)

        # Gradients accumulate in PyTorch, so clear them before every batch.
        # Use the optimizer passed in, not a module-level global.
        optimizer.zero_grad()

        _, _, prediction, confidence = model(batch_image)

        class_loss = loss_function(prediction.float(), batch_label)

        # Confidence target: a two-column [incorrect, correct] indicator for
        # the argmax class prediction; detached so it acts as a fixed label.
        correct = (batch_label.argmax(1) == prediction.detach().argmax(1)).float()
        correct_onehot = torch.stack([1 - correct, correct], dim=1)

        conf_loss = loss_function(confidence.float(), correct_onehot)

        combined_loss = class_loss + conf_loss
        combined_loss.backward()
        optimizer.step()

        train_loss += combined_loss.item()
        progress.set_description(f"train loss = {train_loss/(idx_batch + 1):2.6f}")

    return model, train_loss


def validation_loop(dataloader_val, device, model, loss_function, optimizer):
    """
    Evaluate ``model`` on ``dataloader_val`` with gradients disabled.

    The objective mirrors the training one: ``loss_function`` on the class
    prediction versus the label, plus ``loss_function`` on the confidence
    output versus a two-column [incorrect, correct] target. That target is
    derived from whether the argmax of the prediction matches the argmax of
    the label.

    Parameters
    ----------
    dataloader_val : iterable
        Yields ``(batch_image, batch_label)``; ``batch_label`` is combined with
        ``torch.vstack(...).T`` into one row per sample (presumably a per-class
        sequence from the default collate -- TODO confirm).
    device : torch.device or str
        Target device for the batch tensors.
    model : nn.Module
        Called as ``model(images)``; must return four outputs, of which the
        third is the class prediction and the fourth the confidence output.
    loss_function : callable
        Applied to both outputs.
    optimizer : torch.optim.Optimizer
        Not used; accepted so the call signature matches ``training_loop``.

    Returns
    -------
    model : nn.Module
        The model, left in eval mode.
    val_loss : float
        Sum (not mean) of the per-batch combined losses.
    """
    model.eval()
    progress = tqdm(dataloader_val)
    val_loss = 0.

    with torch.no_grad():
        for n_seen, (images, labels) in enumerate(progress, start=1):
            labels = torch.vstack(labels).T.float().to(device)
            images = images.to(device)

            _, _, prediction, confidence = model(images)

            # 1.0 where the predicted class matches the labelled class.
            hits = (labels.argmax(1) == prediction.argmax(1)).float()
            conf_target = torch.stack((1 - hits, hits), dim=1)

            batch_total = (loss_function(prediction.float(), labels)
                           + loss_function(confidence.float(), conf_target))
            val_loss += batch_total.item()
            progress.set_description(f"validation loss = {val_loss/n_seen:2.6f}")
    return model, val_loss


if __name__ == "__main__":
    # NOTE(review): validation reuses the training transform, which includes
    # rotate=45. If that rotation is random, per-epoch validation losses are
    # noisy; consider a deterministic transform for the validation set.
    transform_steps = concatenate_transform_steps(image_resize=commonsetting.image_resize, rotate=45)
    dataset_train = CustomImageDataset(commonsetting.training_dir, label_map=commonsetting.label_map, transform=transform_steps)
    dataloader_train = DataLoader(dataset_train, batch_size=commonsetting.batch_size, shuffle=True, num_workers=commonsetting.num_workers)
    dataset_val = CustomImageDataset(commonsetting.val_dir, label_map=commonsetting.label_map, transform=transform_steps)
    # Evaluation does not benefit from shuffling; keep a fixed order across epochs.
    dataloader_val = DataLoader(dataset_val, batch_size=commonsetting.batch_size, shuffle=False, num_workers=commonsetting.num_workers)

    SimpleCNN = perceptual_network(pretrained_model_name=commonsetting.pretrained_model_name,
                                   hidden_layer_size=commonsetting.hidden_layer_size, hidden_activation=commonsetting.hidden_activation,
                                   hidden_dropout=commonsetting.hidden_dropout, hidden_layer_type=commonsetting.hidden_layer_type, output_layer_size=commonsetting.output_layer_size,
                                   confidence_layer_size=commonsetting.confidence_layer_size, in_shape=commonsetting.in_shape, retrain_encoder=commonsetting.retrain_encoder,
                                   )
    SimpleCNN = SimpleCNN.to(commonsetting.device)

    # Freeze the whole network, then unfreeze only the three trainable heads.
    trainable_heads = [SimpleCNN.hidden_layer, SimpleCNN.decision_layer, SimpleCNN.confidence_layer]
    for p in SimpleCNN.parameters():
        p.requires_grad = False
    for head in trainable_heads:
        for p in head.parameters():
            p.requires_grad = True

    # One parameter group per head, all with the same learning rate.
    params = [{"params": head.parameters(), "lr": commonsetting.learning_rate}
              for head in trainable_heads]

    # Name kept as-is: it is a module-level global and other code in this
    # file may refer to it.
    optmizer = Adam(params, lr=commonsetting.learning_rate)
    loss_fun = nn.BCELoss()
    # Currently unused: both loops apply `loss_fun` to the confidence head too.
    loss_fun_conf = nn.BCELoss()

    # val_loss is a sum over validation batches, so best_valid_loss and tol
    # are on that scale rather than the per-batch mean shown by tqdm.
    counts = 0
    best_valid_loss = np.inf
    for epoch in range(1000):
        SimpleCNN, train_loss = training_loop(dataloader_train, commonsetting.device, SimpleCNN, loss_fun, optmizer)
        SimpleCNN, val_loss = validation_loop(dataloader_val, commonsetting.device, SimpleCNN, loss_fun, optmizer)
        best_valid_loss, counts = determine_training_stops(SimpleCNN, epoch, warmup_epochs=commonsetting.warmup_epochs, valid_loss=val_loss, counts=counts,
                                                           device=commonsetting.device, best_valid_loss=best_valid_loss, tol=commonsetting.tol,
                                                           f_name="../models/train_pixel_mixed/simplecnn_bs64e4i224h300.h5")
        if counts >= commonsetting.patience:
            # Report why training ended instead of exiting silently.
            print(f'\nepoch {epoch + 1}, early stopping after {counts} epochs without improvement, best valid loss = {best_valid_loss:.8f}')
            break
        print(f'\nepoch {epoch + 1}, best valid loss = {best_valid_loss:.8f},count = {counts}')




train loss = 1.373465: 100%|██████████| 165/165 [00:17<00:00,  9.18it/s]
validation loss = 1.558822: 100%|██████████| 24/24 [00:03<00:00,  7.09it/s]



epoch 1, best valid loss = inf,count = 0


train loss = 1.332334: 100%|██████████| 165/165 [00:15<00:00, 10.68it/s]
validation loss = 1.360966: 100%|██████████| 24/24 [00:03<00:00,  7.20it/s]



epoch 2, best valid loss = inf,count = 0


train loss = 1.316986: 100%|██████████| 165/165 [00:15<00:00, 10.58it/s]
validation loss = 1.322328: 100%|██████████| 24/24 [00:03<00:00,  7.27it/s]



epoch 3, best valid loss = inf,count = 0


train loss = 1.315067: 100%|██████████| 165/165 [00:15<00:00, 10.73it/s]
validation loss = 1.308221: 100%|██████████| 24/24 [00:03<00:00,  7.41it/s]



epoch 4, best valid loss = 31.39730990,count = 0


train loss = 1.311444: 100%|██████████| 165/165 [00:15<00:00, 10.68it/s]
validation loss = 1.310421: 100%|██████████| 24/24 [00:03<00:00,  7.12it/s]



epoch 5, best valid loss = 31.39730990,count = 1


train loss = 1.304106: 100%|██████████| 165/165 [00:15<00:00, 10.54it/s]
validation loss = 1.297778: 100%|██████████| 24/24 [00:03<00:00,  7.52it/s]



epoch 6, best valid loss = 31.14667571,count = 0


train loss = 1.296633: 100%|██████████| 165/165 [00:15<00:00, 10.39it/s]
validation loss = 1.315995: 100%|██████████| 24/24 [00:03<00:00,  7.49it/s]



epoch 7, best valid loss = 31.14667571,count = 1


train loss = 1.297641: 100%|██████████| 165/165 [00:15<00:00, 10.55it/s]
validation loss = 1.307085: 100%|██████████| 24/24 [00:03<00:00,  7.35it/s]



epoch 8, best valid loss = 31.14667571,count = 2


train loss = 1.303332: 100%|██████████| 165/165 [00:15<00:00, 10.44it/s]
validation loss = 1.323300: 100%|██████████| 24/24 [00:03<00:00,  7.60it/s]



epoch 9, best valid loss = 31.14667571,count = 3


train loss = 1.298248: 100%|██████████| 165/165 [00:15<00:00, 10.55it/s]
validation loss = 1.289744: 100%|██████████| 24/24 [00:02<00:00,  8.32it/s]



epoch 10, best valid loss = 30.95385337,count = 0


train loss = 1.298705: 100%|██████████| 165/165 [00:14<00:00, 11.48it/s]
validation loss = 1.288167: 100%|██████████| 24/24 [00:02<00:00,  8.31it/s]



epoch 11, best valid loss = 30.91601574,count = 0


train loss = 1.296714: 100%|██████████| 165/165 [00:14<00:00, 11.51it/s]
validation loss = 1.304658: 100%|██████████| 24/24 [00:03<00:00,  7.99it/s]



epoch 12, best valid loss = 30.91601574,count = 1


train loss = 1.296984: 100%|██████████| 165/165 [00:14<00:00, 11.47it/s]
validation loss = 1.285163: 100%|██████████| 24/24 [00:02<00:00,  8.16it/s]



epoch 13, best valid loss = 30.84391797,count = 0


train loss = 1.294294: 100%|██████████| 165/165 [00:15<00:00, 10.55it/s]
validation loss = 1.279710: 100%|██████████| 24/24 [00:02<00:00,  8.48it/s]



epoch 14, best valid loss = 30.71303391,count = 0


train loss = 1.293441: 100%|██████████| 165/165 [00:14<00:00, 11.54it/s]
validation loss = 1.285424: 100%|██████████| 24/24 [00:02<00:00,  8.53it/s]



epoch 15, best valid loss = 30.71303391,count = 1


train loss = 1.293039: 100%|██████████| 165/165 [00:14<00:00, 11.65it/s]
validation loss = 1.294545: 100%|██████████| 24/24 [00:02<00:00,  8.43it/s]



epoch 16, best valid loss = 30.71303391,count = 2


train loss = 1.291587: 100%|██████████| 165/165 [00:14<00:00, 11.68it/s]
validation loss = 1.282673: 100%|██████████| 24/24 [00:02<00:00,  8.48it/s]



epoch 17, best valid loss = 30.71303391,count = 3


train loss = 1.294683: 100%|██████████| 165/165 [00:14<00:00, 11.49it/s]
validation loss = 1.273195: 100%|██████████| 24/24 [00:02<00:00,  8.51it/s]



epoch 18, best valid loss = 30.55668402,count = 0


train loss = 1.290854: 100%|██████████| 165/165 [00:14<00:00, 11.75it/s]
validation loss = 1.286510: 100%|██████████| 24/24 [00:02<00:00,  8.82it/s]



epoch 19, best valid loss = 30.55668402,count = 1


train loss = 1.288276: 100%|██████████| 165/165 [00:14<00:00, 11.73it/s]
validation loss = 1.291828: 100%|██████████| 24/24 [00:02<00:00,  8.26it/s]



epoch 20, best valid loss = 30.55668402,count = 2


train loss = 1.289750: 100%|██████████| 165/165 [00:14<00:00, 11.69it/s]
validation loss = 1.282001: 100%|██████████| 24/24 [00:02<00:00,  8.83it/s]



epoch 21, best valid loss = 30.55668402,count = 3


train loss = 1.288874: 100%|██████████| 165/165 [00:14<00:00, 11.66it/s]
validation loss = 1.292189: 100%|██████████| 24/24 [00:02<00:00,  8.40it/s]



epoch 22, best valid loss = 30.55668402,count = 4


train loss = 1.288790: 100%|██████████| 165/165 [00:14<00:00, 11.47it/s]
validation loss = 1.282663: 100%|██████████| 24/24 [00:02<00:00,  8.36it/s]



epoch 23, best valid loss = 30.55668402,count = 5


train loss = 1.285902: 100%|██████████| 165/165 [00:14<00:00, 11.45it/s]
validation loss = 1.280845: 100%|██████████| 24/24 [00:02<00:00,  8.43it/s]



epoch 24, best valid loss = 30.55668402,count = 6


train loss = 1.285835: 100%|██████████| 165/165 [00:14<00:00, 11.52it/s]
validation loss = 1.280695: 100%|██████████| 24/24 [00:02<00:00,  8.02it/s]



epoch 25, best valid loss = 30.55668402,count = 7


train loss = 1.290268: 100%|██████████| 165/165 [00:14<00:00, 11.49it/s]
validation loss = 1.286156: 100%|██████████| 24/24 [00:02<00:00,  8.16it/s]



epoch 26, best valid loss = 30.55668402,count = 8


train loss = 1.286363: 100%|██████████| 165/165 [00:14<00:00, 11.53it/s]
validation loss = 1.286415: 100%|██████████| 24/24 [00:03<00:00,  7.98it/s]



epoch 27, best valid loss = 30.55668402,count = 9


train loss = 1.283534: 100%|██████████| 165/165 [00:14<00:00, 11.50it/s]
validation loss = 1.281514: 100%|██████████| 24/24 [00:03<00:00,  7.84it/s]



epoch 28, best valid loss = 30.55668402,count = 10


train loss = 1.285630: 100%|██████████| 165/165 [00:14<00:00, 11.58it/s]
validation loss = 1.290145: 100%|██████████| 24/24 [00:02<00:00,  8.41it/s]



epoch 29, best valid loss = 30.55668402,count = 11


train loss = 1.287744: 100%|██████████| 165/165 [00:14<00:00, 11.48it/s]
validation loss = 1.278110: 100%|██████████| 24/24 [00:02<00:00,  8.59it/s]



epoch 30, best valid loss = 30.55668402,count = 12


train loss = 1.285583: 100%|██████████| 165/165 [00:15<00:00, 10.60it/s]
validation loss = 1.274959: 100%|██████████| 24/24 [00:03<00:00,  7.35it/s]



epoch 31, best valid loss = 30.55668402,count = 13


train loss = 1.286679: 100%|██████████| 165/165 [00:15<00:00, 10.67it/s]
validation loss = 1.283510: 100%|██████████| 24/24 [00:02<00:00,  8.11it/s]



epoch 32, best valid loss = 30.55668402,count = 14


train loss = 1.283919: 100%|██████████| 165/165 [00:14<00:00, 11.46it/s]
validation loss = 1.278450: 100%|██████████| 24/24 [00:02<00:00,  8.41it/s]



epoch 33, best valid loss = 30.55668402,count = 15


train loss = 1.283048: 100%|██████████| 165/165 [00:14<00:00, 11.21it/s]
validation loss = 1.273300: 100%|██████████| 24/24 [00:02<00:00,  8.37it/s]



epoch 34, best valid loss = 30.55668402,count = 16


train loss = 1.281756: 100%|██████████| 165/165 [00:15<00:00, 10.82it/s]
validation loss = 1.280212: 100%|██████████| 24/24 [00:03<00:00,  7.74it/s]



epoch 35, best valid loss = 30.55668402,count = 17


train loss = 1.281986: 100%|██████████| 165/165 [00:15<00:00, 10.74it/s]
validation loss = 1.287300: 100%|██████████| 24/24 [00:03<00:00,  7.76it/s]



epoch 36, best valid loss = 30.55668402,count = 18


train loss = 1.283283: 100%|██████████| 165/165 [00:14<00:00, 11.39it/s]
validation loss = 1.269325: 100%|██████████| 24/24 [00:03<00:00,  7.69it/s]



epoch 37, best valid loss = 30.46379697,count = 0


train loss = 1.280838: 100%|██████████| 165/165 [00:14<00:00, 11.36it/s]
validation loss = 1.271264: 100%|██████████| 24/24 [00:02<00:00,  8.28it/s]



epoch 38, best valid loss = 30.46379697,count = 1


train loss = 1.279306: 100%|██████████| 165/165 [00:14<00:00, 11.48it/s]
validation loss = 1.275604: 100%|██████████| 24/24 [00:02<00:00,  8.39it/s]



epoch 39, best valid loss = 30.46379697,count = 2


train loss = 1.281770: 100%|██████████| 165/165 [00:14<00:00, 11.38it/s]
validation loss = 1.290841: 100%|██████████| 24/24 [00:02<00:00,  8.18it/s]



epoch 40, best valid loss = 30.46379697,count = 3


train loss = 1.286764: 100%|██████████| 165/165 [00:14<00:00, 11.14it/s]
validation loss = 1.277269: 100%|██████████| 24/24 [00:02<00:00,  8.10it/s]



epoch 41, best valid loss = 30.46379697,count = 4


train loss = 1.281169: 100%|██████████| 165/165 [00:14<00:00, 11.51it/s]
validation loss = 1.265894: 100%|██████████| 24/24 [00:03<00:00,  7.70it/s]



epoch 42, best valid loss = 30.38145363,count = 0


train loss = 1.279178: 100%|██████████| 165/165 [00:14<00:00, 11.73it/s]
validation loss = 1.261537: 100%|██████████| 24/24 [00:02<00:00,  8.25it/s]



epoch 43, best valid loss = 30.27689183,count = 0


train loss = 1.279267: 100%|██████████| 165/165 [00:14<00:00, 11.62it/s]
validation loss = 1.270977: 100%|██████████| 24/24 [00:03<00:00,  7.99it/s]



epoch 44, best valid loss = 30.27689183,count = 1


train loss = 1.280373: 100%|██████████| 165/165 [00:14<00:00, 11.08it/s]
validation loss = 1.289353: 100%|██████████| 24/24 [00:03<00:00,  7.68it/s]



epoch 45, best valid loss = 30.27689183,count = 2


train loss = 1.279218: 100%|██████████| 165/165 [00:15<00:00, 10.85it/s]
validation loss = 1.281051: 100%|██████████| 24/24 [00:02<00:00,  8.28it/s]



epoch 46, best valid loss = 30.27689183,count = 3


train loss = 1.286456: 100%|██████████| 165/165 [00:14<00:00, 11.33it/s]
validation loss = 1.267780: 100%|██████████| 24/24 [00:02<00:00,  8.67it/s]



epoch 47, best valid loss = 30.27689183,count = 4


train loss = 1.279961: 100%|██████████| 165/165 [00:14<00:00, 11.49it/s]
validation loss = 1.285452: 100%|██████████| 24/24 [00:02<00:00,  8.33it/s]



epoch 48, best valid loss = 30.27689183,count = 5


train loss = 1.278901: 100%|██████████| 165/165 [00:14<00:00, 11.41it/s]
validation loss = 1.261215: 100%|██████████| 24/24 [00:02<00:00,  8.06it/s]



epoch 49, best valid loss = 30.26916599,count = 0


train loss = 1.277982: 100%|██████████| 165/165 [00:14<00:00, 11.57it/s]
validation loss = 1.283352: 100%|██████████| 24/24 [00:02<00:00,  8.58it/s]



epoch 50, best valid loss = 30.26916599,count = 1


train loss = 1.278794: 100%|██████████| 165/165 [00:14<00:00, 11.45it/s]
validation loss = 1.265409: 100%|██████████| 24/24 [00:02<00:00,  8.45it/s]



epoch 51, best valid loss = 30.26916599,count = 2


train loss = 1.278610: 100%|██████████| 165/165 [00:15<00:00, 10.76it/s]
validation loss = 1.262045: 100%|██████████| 24/24 [00:03<00:00,  7.42it/s]



epoch 52, best valid loss = 30.26916599,count = 3


train loss = 1.277477: 100%|██████████| 165/165 [00:15<00:00, 10.84it/s]
validation loss = 1.259781: 100%|██████████| 24/24 [00:03<00:00,  7.68it/s]



epoch 53, best valid loss = 30.23473728,count = 0


train loss = 1.276419: 100%|██████████| 165/165 [00:14<00:00, 11.09it/s]
validation loss = 1.263688: 100%|██████████| 24/24 [00:02<00:00,  8.04it/s]



epoch 54, best valid loss = 30.23473728,count = 1


train loss = 1.273104: 100%|██████████| 165/165 [00:15<00:00, 10.82it/s]
validation loss = 1.264842: 100%|██████████| 24/24 [00:02<00:00,  8.18it/s]



epoch 55, best valid loss = 30.23473728,count = 2


train loss = 1.280727: 100%|██████████| 165/165 [00:14<00:00, 11.32it/s]
validation loss = 1.261723: 100%|██████████| 24/24 [00:02<00:00,  8.45it/s]



epoch 56, best valid loss = 30.23473728,count = 3


train loss = 1.280663: 100%|██████████| 165/165 [00:14<00:00, 11.37it/s]
validation loss = 1.253806: 100%|██████████| 24/24 [00:02<00:00,  8.09it/s]



epoch 57, best valid loss = 30.09134054,count = 0


train loss = 1.279798: 100%|██████████| 165/165 [00:14<00:00, 11.25it/s]
validation loss = 1.298640: 100%|██████████| 24/24 [00:02<00:00,  8.14it/s]



epoch 58, best valid loss = 30.09134054,count = 1


train loss = 1.281667: 100%|██████████| 165/165 [00:14<00:00, 11.14it/s]
validation loss = 1.271245: 100%|██████████| 24/24 [00:02<00:00,  8.37it/s]



epoch 59, best valid loss = 30.09134054,count = 2


train loss = 1.274551: 100%|██████████| 165/165 [00:14<00:00, 11.62it/s]
validation loss = 1.260061: 100%|██████████| 24/24 [00:02<00:00,  8.56it/s]



epoch 60, best valid loss = 30.09134054,count = 3


train loss = 1.272450: 100%|██████████| 165/165 [00:14<00:00, 11.51it/s]
validation loss = 1.274683: 100%|██████████| 24/24 [00:02<00:00,  8.43it/s]



epoch 61, best valid loss = 30.09134054,count = 4


train loss = 1.281234: 100%|██████████| 165/165 [00:14<00:00, 11.70it/s]
validation loss = 1.262345: 100%|██████████| 24/24 [00:02<00:00,  8.46it/s]



epoch 62, best valid loss = 30.09134054,count = 5


train loss = 1.278765: 100%|██████████| 165/165 [00:14<00:00, 11.78it/s]
validation loss = 1.264330: 100%|██████████| 24/24 [00:02<00:00,  8.50it/s]



epoch 63, best valid loss = 30.09134054,count = 6


train loss = 1.280715: 100%|██████████| 165/165 [00:14<00:00, 11.64it/s]
validation loss = 1.264778: 100%|██████████| 24/24 [00:02<00:00,  8.64it/s]



epoch 64, best valid loss = 30.09134054,count = 7


train loss = 1.277496: 100%|██████████| 165/165 [00:14<00:00, 11.69it/s]
validation loss = 1.260675: 100%|██████████| 24/24 [00:02<00:00,  8.47it/s]



epoch 65, best valid loss = 30.09134054,count = 8


train loss = 1.274075: 100%|██████████| 165/165 [00:14<00:00, 11.68it/s]
validation loss = 1.267670: 100%|██████████| 24/24 [00:02<00:00,  8.29it/s]



epoch 66, best valid loss = 30.09134054,count = 9


train loss = 1.278751: 100%|██████████| 165/165 [00:14<00:00, 11.59it/s]
validation loss = 1.273768: 100%|██████████| 24/24 [00:02<00:00,  8.53it/s]



epoch 67, best valid loss = 30.09134054,count = 10


train loss = 1.274116: 100%|██████████| 165/165 [00:14<00:00, 11.57it/s]
validation loss = 1.272257: 100%|██████████| 24/24 [00:02<00:00,  8.56it/s]



epoch 68, best valid loss = 30.09134054,count = 11


train loss = 1.276895: 100%|██████████| 165/165 [00:14<00:00, 11.75it/s]
validation loss = 1.261155: 100%|██████████| 24/24 [00:02<00:00,  8.43it/s]



epoch 69, best valid loss = 30.09134054,count = 12


train loss = 1.276290: 100%|██████████| 165/165 [00:14<00:00, 11.57it/s]
validation loss = 1.261511: 100%|██████████| 24/24 [00:02<00:00,  8.46it/s]



epoch 70, best valid loss = 30.09134054,count = 13


train loss = 1.274568: 100%|██████████| 165/165 [00:14<00:00, 11.55it/s]
validation loss = 1.279248: 100%|██████████| 24/24 [00:02<00:00,  8.25it/s]



epoch 71, best valid loss = 30.09134054,count = 14


train loss = 1.278554: 100%|██████████| 165/165 [00:14<00:00, 11.77it/s]
validation loss = 1.267244: 100%|██████████| 24/24 [00:02<00:00,  8.07it/s]



epoch 72, best valid loss = 30.09134054,count = 15


train loss = 1.271434: 100%|██████████| 165/165 [00:13<00:00, 11.80it/s]
validation loss = 1.247473: 100%|██████████| 24/24 [00:02<00:00,  8.32it/s]



epoch 73, best valid loss = 29.93935287,count = 0


train loss = 1.274967: 100%|██████████| 165/165 [00:14<00:00, 11.39it/s]
validation loss = 1.265133: 100%|██████████| 24/24 [00:02<00:00,  8.07it/s]



epoch 74, best valid loss = 29.93935287,count = 1


train loss = 1.273424: 100%|██████████| 165/165 [00:14<00:00, 11.74it/s]
validation loss = 1.259205: 100%|██████████| 24/24 [00:03<00:00,  7.94it/s]



epoch 75, best valid loss = 29.93935287,count = 2


train loss = 1.278895: 100%|██████████| 165/165 [00:14<00:00, 11.63it/s]
validation loss = 1.271061: 100%|██████████| 24/24 [00:02<00:00,  8.38it/s]



epoch 76, best valid loss = 29.93935287,count = 3


train loss = 1.276802: 100%|██████████| 165/165 [00:14<00:00, 11.13it/s]
validation loss = 1.273211: 100%|██████████| 24/24 [00:02<00:00,  8.45it/s]



epoch 77, best valid loss = 29.93935287,count = 4


train loss = 1.277662: 100%|██████████| 165/165 [00:14<00:00, 11.64it/s]
validation loss = 1.240782: 100%|██████████| 24/24 [00:02<00:00,  8.52it/s]



epoch 78, best valid loss = 29.77876544,count = 0


train loss = 1.274922: 100%|██████████| 165/165 [00:14<00:00, 11.48it/s]
validation loss = 1.258457: 100%|██████████| 24/24 [00:02<00:00,  8.57it/s]



epoch 79, best valid loss = 29.77876544,count = 1


train loss = 1.277167: 100%|██████████| 165/165 [00:13<00:00, 11.82it/s]
validation loss = 1.250494: 100%|██████████| 24/24 [00:02<00:00,  8.33it/s]



epoch 80, best valid loss = 29.77876544,count = 2


train loss = 1.275398: 100%|██████████| 165/165 [00:14<00:00, 11.26it/s]
validation loss = 1.257410: 100%|██████████| 24/24 [00:02<00:00,  8.46it/s]



epoch 81, best valid loss = 29.77876544,count = 3


train loss = 1.275971: 100%|██████████| 165/165 [00:14<00:00, 11.61it/s]
validation loss = 1.258798: 100%|██████████| 24/24 [00:02<00:00,  8.08it/s]



epoch 82, best valid loss = 29.77876544,count = 4


train loss = 1.271414: 100%|██████████| 165/165 [00:14<00:00, 11.55it/s]
validation loss = 1.264353: 100%|██████████| 24/24 [00:02<00:00,  8.05it/s]



epoch 83, best valid loss = 29.77876544,count = 5


train loss = 1.274178: 100%|██████████| 165/165 [00:14<00:00, 11.62it/s]
validation loss = 1.274132: 100%|██████████| 24/24 [00:02<00:00,  8.51it/s]



epoch 84, best valid loss = 29.77876544,count = 6


train loss = 1.273914: 100%|██████████| 165/165 [00:14<00:00, 11.69it/s]
validation loss = 1.261426: 100%|██████████| 24/24 [00:02<00:00,  8.53it/s]



epoch 85, best valid loss = 29.77876544,count = 7


train loss = 1.274003: 100%|██████████| 165/165 [00:14<00:00, 11.52it/s]
validation loss = 1.263500: 100%|██████████| 24/24 [00:02<00:00,  8.40it/s]



epoch 86, best valid loss = 29.77876544,count = 8


train loss = 1.268250: 100%|██████████| 165/165 [00:14<00:00, 11.56it/s]
validation loss = 1.275760: 100%|██████████| 24/24 [00:02<00:00,  8.45it/s]



epoch 87, best valid loss = 29.77876544,count = 9


train loss = 1.270465: 100%|██████████| 165/165 [00:14<00:00, 11.61it/s]
validation loss = 1.255488: 100%|██████████| 24/24 [00:02<00:00,  8.79it/s]



epoch 88, best valid loss = 29.77876544,count = 10


train loss = 1.270700: 100%|██████████| 165/165 [00:14<00:00, 11.51it/s]
validation loss = 1.261228: 100%|██████████| 24/24 [00:02<00:00,  8.53it/s]



epoch 89, best valid loss = 29.77876544,count = 11


train loss = 1.269907: 100%|██████████| 165/165 [00:14<00:00, 11.38it/s]
validation loss = 1.266781: 100%|██████████| 24/24 [00:02<00:00,  8.42it/s]



epoch 90, best valid loss = 29.77876544,count = 12


train loss = 1.271997: 100%|██████████| 165/165 [00:14<00:00, 11.68it/s]
validation loss = 1.259169: 100%|██████████| 24/24 [00:02<00:00,  8.26it/s]



epoch 91, best valid loss = 29.77876544,count = 13


train loss = 1.273287: 100%|██████████| 165/165 [00:14<00:00, 11.59it/s]
validation loss = 1.277349: 100%|██████████| 24/24 [00:02<00:00,  8.61it/s]



epoch 92, best valid loss = 29.77876544,count = 14


train loss = 1.270422: 100%|██████████| 165/165 [00:14<00:00, 11.51it/s]
validation loss = 1.266974: 100%|██████████| 24/24 [00:02<00:00,  8.40it/s]



epoch 93, best valid loss = 29.77876544,count = 15


train loss = 1.273985: 100%|██████████| 165/165 [00:14<00:00, 11.58it/s]
validation loss = 1.256354: 100%|██████████| 24/24 [00:02<00:00,  8.09it/s]



epoch 94, best valid loss = 29.77876544,count = 16


train loss = 1.274049: 100%|██████████| 165/165 [00:14<00:00, 11.59it/s]
validation loss = 1.265783: 100%|██████████| 24/24 [00:02<00:00,  8.42it/s]



epoch 95, best valid loss = 29.77876544,count = 17


train loss = 1.270693: 100%|██████████| 165/165 [00:14<00:00, 11.75it/s]
validation loss = 1.272172: 100%|██████████| 24/24 [00:02<00:00,  8.46it/s]



epoch 96, best valid loss = 29.77876544,count = 18


train loss = 1.271090: 100%|██████████| 165/165 [00:13<00:00, 11.82it/s]
validation loss = 1.265549: 100%|██████████| 24/24 [00:02<00:00,  8.83it/s]



epoch 97, best valid loss = 29.77876544,count = 19


train loss = 1.273049: 100%|██████████| 165/165 [00:13<00:00, 11.80it/s]
validation loss = 1.258720: 100%|██████████| 24/24 [00:02<00:00,  8.46it/s]


In [None]:
!/usr/bin/shutdown