In [2]:
import sys
import json
import numpy as np
from tqdm.notebook import tqdm
from toolz.curried import pipe, curry, compose

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, utils
from torch.utils.data import Dataset, DataLoader

import chnet.ch_tools as ch_tools
import chnet.utilities as ch_utils
import chnet.ch_generator as ch_gen
from chnet.ch_loader import CahnHillDataset
from chnet.models import UNet, UNet_solo_loop, UNet_loop, mse_loss

In [55]:
def train(key="unet", mid=0.0, dif=0.449, dim_x=96, dx=0.25, dt=0.01, 
            gamma=0.2, init_steps=1, nstep=20, n_samples_trn=1024, 
            ngf=32, final_tstep = 5000, num_epochs=10, 
            learning_rate=1.0e-5, n_primes=2000, 
            device="cuda", save=False):
    """Train a Cahn-Hilliard surrogate model on a sequence of generated datasets.

    One training set is generated per prime returned by
    ``ch_utils.get_primes(n_primes)`` (the generator is seeded with
    ``2513 * prime``), and the model is trained for ``num_epochs`` epochs on
    each of those datasets in turn, reusing the same optimizer throughout.

    Args:
        key: model variant; one of "unet", "unet_solo_loop", "unet_loop".
        mid, dif: the generator receives ``m_l = mid - dif`` and ``m_r = mid + dif``.
        dim_x, dx, dt, gamma: forwarded unchanged to ``ch_gen.data_generator``.
        init_steps, final_tstep, nstep: the generator's ``delta_sim_steps`` is
            ``(final_tstep - init_steps) // nstep``; ``nstep`` is also forwarded
            as ``n_step`` and, for the loop variants, as the model's ``temporal``.
        n_samples_trn: number of samples generated per dataset.
        ngf: ``init_features`` of the UNet.
        num_epochs: epochs run on each generated dataset.
        learning_rate: Adam learning rate.
        n_primes: argument to ``ch_utils.get_primes``; presumably primes below
            this value are used (the logged run got 25 datasets for 100) — confirm.
        device: "cuda" maps to "cuda:0"; any other string goes to ``torch.device``.
        save: if True, write ``{"state": ..., "losses": ...}`` to the printed
            weights path after every dataset. Defaults to False (no file written).

    Returns:
        The trained model (double precision, on ``device``).

    Raises:
        ValueError: if ``key`` is not one of the supported model variants.
    """
    import os

    valid_keys = ("unet", "unet_solo_loop", "unet_loop")
    if key not in valid_keys:
        # Fail fast: an unknown key used to leave `model` unbound and crash
        # only after the (potentially slow) setup had already run.
        raise ValueError("unknown model key {!r}; expected one of {}".format(key, valid_keys))

    m_l = mid - dif
    m_r = mid + dif
    delta_sim_steps = (final_tstep - init_steps) // nstep
    primes = ch_utils.get_primes(n_primes)

    print("no. of datasets: {}".format(len(primes)))

    # "cuda" keeps its historical meaning of the first GPU; explicit device
    # strings such as "cuda:1" or "cpu" are honoured instead of falling back to CPU.
    device = torch.device("cuda:0") if device == "cuda" else torch.device(device)
    print(device)

    if key == "unet":
        model = UNet(in_channels=1, out_channels=1, init_features=ngf, tanh=False)
    else:
        model_cls = UNet_solo_loop if key == "unet_solo_loop" else UNet_loop
        model = model_cls(in_channels=1, out_channels=1, init_features=ngf, temporal=nstep, tanh=False)
    model = model.double().to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    trn_losses = []

    # NOTE(review): the "tstep" field is num_epochs * len(primes), i.e. the total
    # number of epochs over all datasets, not a simulation time step.
    fout = "weights/model_{}_size_{}_step_{}_init_{}_delta_{}_tstep_{}.pt".format(key, ngf, nstep, init_steps, delta_sim_steps, num_epochs*len(primes))  
    print("model saved at: {}".format(fout))

    print("Start Training")
    # Nothing below switches to eval mode, so one call suffices.
    model.train()
    for num, prime in enumerate(primes):
        if device.type == "cuda":
            # Return unused cached GPU memory before generating the next dataset.
            torch.cuda.empty_cache()

        x_trn, y_trn = ch_gen.data_generator(nsamples=n_samples_trn,
                                             dim_x=dim_x,
                                             init_steps=init_steps,
                                             delta_sim_steps=delta_sim_steps,
                                             dx=dx,
                                             dt=dt,
                                             m_l=m_l,
                                             m_r=m_r,
                                             n_step=nstep,
                                             gamma=gamma,
                                             seed=2513*prime,
                                             device=device)

        # Insert a singleton channel axis at dim 1 for both inputs and targets.
        trn_dataset = CahnHillDataset(x_trn, y_trn,
                                      transform_x=lambda x: x[:, None],
                                      transform_y=lambda x: x[:, None])

        # NOTE(review): if data_generator returns CUDA tensors, DataLoader worker
        # processes (num_workers=4) may not be able to use them on every
        # platform — confirm.
        trn_loader = DataLoader(trn_dataset,
                                batch_size=8,
                                shuffle=True,
                                num_workers=4)

        print("Training Run: {}, prime: {}".format(num, prime))

        total_step = len(trn_loader)

        for epoch in range(num_epochs):
            for item_trn in tqdm(trn_loader):
                # Only index 0 along dim 1 of 'x' is used as the model input.
                x = item_trn['x'][:, 0].to(device)
                if "loop" in key:
                    # Loop variants are trained against the full target stack.
                    y_tru = item_trn['y'].to(device)
                else:
                    # The plain UNet is trained against the last target only.
                    y_tru = item_trn['y'][:, -1].to(device)

                y_prd = model(x)  # forward pass
                # Extra term penalising a mismatch between per-sample input and
                # prediction means (presumably mass conservation — confirm).
                means_inp = x.mean(axis=(1, 2, 3))
                means_out = y_prd.mean(axis=(1, 2, 3))
                loss = mse_loss(y_tru, y_prd, scale=10000) + mse_loss(means_inp, means_out, scale=10000)

                # Backward and optimize
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                trn_losses.append(np.sqrt(loss.item()))

            # Average over the batches of the epoch that just finished.
            print ('Epoch [{}/{}], Training Loss: {:.11f}'.format(epoch+1, num_epochs, np.mean(trn_losses[-total_step:])))

        if save:
            # Checkpoint after every dataset; torch.save does not create directories.
            os.makedirs(os.path.dirname(fout), exist_ok=True)
            torch.save({"state": model.state_dict(), "losses": trn_losses}, fout)

    print("End Training")
    return model

In [56]:
# Keyword arguments for a single `train(**arguments)` run. Key order is kept
# stable so the echoed configuration reads the same in every log.
# With final_tstep=501, init_steps=1 and nstep=2, train() derives
# delta_sim_steps = (501 - 1) // 2 = 250.
arguments = dict(
    key="unet",
    mid=0.0,
    dif=0.449,
    dim_x=96,
    init_steps=1,
    dx=0.25,
    dt=0.01,
    gamma=0.2,
    n_samples_trn=1024,
    ngf=64,
    nstep=2,
    final_tstep=501,
    num_epochs=10,
    n_primes=100,
    learning_rate=1.0e-5,
    device="cuda",
)

In [57]:
# Echo the exact configuration into the cell output so the training log below
# records which settings produced it, then train. `model` is consumed by the
# later torch.save cell.
print(arguments)
model = train(**arguments)

{'key': 'unet', 'mid': 0.0, 'dif': 0.449, 'dim_x': 96, 'init_steps': 1, 'dx': 0.25, 'dt': 0.01, 'gamma': 0.2, 'n_samples_trn': 1024, 'ngf': 64, 'nstep': 2, 'final_tstep': 501, 'num_epochs': 10, 'n_primes': 100, 'learning_rate': 1e-05, 'device': 'cuda'}
no. of datasets: 25
cuda:0
model saved at: weights/model_unet_size_64_step_2_init_1_delta_250_tstep_250.pt
Start Training
Training Run: 0, prime: 2


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [1/10], Training Loss: 26.55668243341


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [2/10], Training Loss: 21.58521788465


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [3/10], Training Loss: 21.32111103601


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [4/10], Training Loss: 19.27267714665


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [5/10], Training Loss: 18.12980849381


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [6/10], Training Loss: 15.72821942210


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [7/10], Training Loss: 14.06875354893


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [8/10], Training Loss: 13.68091566525


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [9/10], Training Loss: 13.94062109030


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [10/10], Training Loss: 13.35232178746
Training Run: 1, prime: 3


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [1/10], Training Loss: 12.90852910782


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [2/10], Training Loss: 12.36938583248


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [3/10], Training Loss: 12.45757055709


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [4/10], Training Loss: 12.53355336665


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [5/10], Training Loss: 11.40633342969


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [6/10], Training Loss: 11.12333069570


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [7/10], Training Loss: 11.71391859854


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [8/10], Training Loss: 12.01009974076


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [9/10], Training Loss: 12.85229125602


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [10/10], Training Loss: 11.38395103715
Training Run: 2, prime: 5


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [1/10], Training Loss: 11.94965189437


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [2/10], Training Loss: 11.58448712283


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [3/10], Training Loss: 10.54798889112


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [4/10], Training Loss: 11.42917007279


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [5/10], Training Loss: 10.84410851203


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [6/10], Training Loss: 10.03925776982


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [7/10], Training Loss: 10.68833257010


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [8/10], Training Loss: 11.40155433574


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [9/10], Training Loss: 10.86966844288


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [10/10], Training Loss: 11.16409667726
Training Run: 3, prime: 7


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [1/10], Training Loss: 10.22229060344


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [2/10], Training Loss: 10.81493614372


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [3/10], Training Loss: 11.48476908676


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [4/10], Training Loss: 11.20648087667


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [5/10], Training Loss: 10.63747723147


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [6/10], Training Loss: 10.25255985366


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [7/10], Training Loss: 10.37728703248


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [8/10], Training Loss: 10.16193021370


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [9/10], Training Loss: 10.58655454121


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [10/10], Training Loss: 10.13779683003
Training Run: 4, prime: 11


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [1/10], Training Loss: 10.98457945299


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [2/10], Training Loss: 11.61673035723


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [3/10], Training Loss: 10.70552534309


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [4/10], Training Loss: 10.55870642814


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [5/10], Training Loss: 12.34735657837


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [6/10], Training Loss: 11.78743331189


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [7/10], Training Loss: 10.53916936636


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [8/10], Training Loss: 9.28266043621


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [9/10], Training Loss: 12.08274512138


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [10/10], Training Loss: 11.06488906963
Training Run: 5, prime: 13


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [1/10], Training Loss: 10.18175080865


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [2/10], Training Loss: 9.51770852460


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [3/10], Training Loss: 10.48642345401


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [4/10], Training Loss: 9.57488632609


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [5/10], Training Loss: 10.03734198631


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [6/10], Training Loss: 10.11189170766


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [7/10], Training Loss: 9.85772592093


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [8/10], Training Loss: 9.77700216704


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [9/10], Training Loss: 9.28396924900


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [10/10], Training Loss: 8.63210486964
Training Run: 6, prime: 17


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [1/10], Training Loss: 9.81723203209


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [2/10], Training Loss: 10.25981617605


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [3/10], Training Loss: 9.99217346130


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [4/10], Training Loss: 9.82415975849


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [5/10], Training Loss: 11.38209415752


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [6/10], Training Loss: 10.12245881036


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [7/10], Training Loss: 9.51393737557


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [8/10], Training Loss: 10.34057971032


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [9/10], Training Loss: 10.47215520977


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [10/10], Training Loss: 11.44044987586
Training Run: 7, prime: 19


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [1/10], Training Loss: 10.90975369555


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [2/10], Training Loss: 10.00137747631


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [3/10], Training Loss: 9.46607368516


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [4/10], Training Loss: 10.10095013975


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [5/10], Training Loss: 9.24263843559


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [6/10], Training Loss: 9.26117765718


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [7/10], Training Loss: 8.25123276343


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [8/10], Training Loss: 10.23979309202


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [9/10], Training Loss: 10.17367441278


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [10/10], Training Loss: 10.26685410330
Training Run: 8, prime: 23


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [1/10], Training Loss: 9.57537255203


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [2/10], Training Loss: 9.50237588892


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [3/10], Training Loss: 10.83516750691


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [4/10], Training Loss: 10.48274254071


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [5/10], Training Loss: 9.70819470690


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [6/10], Training Loss: 8.53728910883


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [7/10], Training Loss: 9.57319505768


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [8/10], Training Loss: 9.90084494150


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [9/10], Training Loss: 9.98192914114


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [10/10], Training Loss: 9.48186078350
Training Run: 9, prime: 29


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [1/10], Training Loss: 10.04677731087


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [2/10], Training Loss: 9.33902182001


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [3/10], Training Loss: 9.64192896377


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [4/10], Training Loss: 10.34666383072


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [5/10], Training Loss: 10.01856999231


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [6/10], Training Loss: 9.56002219294


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [7/10], Training Loss: 9.60184863427


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [8/10], Training Loss: 8.45738053936


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [9/10], Training Loss: 9.11766574261


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [10/10], Training Loss: 9.65233213449
Training Run: 10, prime: 31


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [1/10], Training Loss: 9.69747377353


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [2/10], Training Loss: 9.29904735143


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [3/10], Training Loss: 9.78258597192


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [4/10], Training Loss: 8.84940215579


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [5/10], Training Loss: 10.60871905300


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [6/10], Training Loss: 10.05669833108


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [7/10], Training Loss: 8.51616515882


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [8/10], Training Loss: 9.85630452348


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [9/10], Training Loss: 9.94368003973


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [10/10], Training Loss: 8.87908727704
Training Run: 11, prime: 37


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [1/10], Training Loss: 11.11333866056


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [2/10], Training Loss: 9.95570967807


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [3/10], Training Loss: 9.31363316844


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [4/10], Training Loss: 9.91929308760


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [5/10], Training Loss: 10.45259824807


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [6/10], Training Loss: 10.37536524695


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [7/10], Training Loss: 9.44840106832


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [8/10], Training Loss: 10.14969655398


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [9/10], Training Loss: 10.25452481117


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [10/10], Training Loss: 10.27621679620
Training Run: 12, prime: 41


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [1/10], Training Loss: 10.33317177275


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [2/10], Training Loss: 9.44745238815


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [3/10], Training Loss: 10.65714484339


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [4/10], Training Loss: 9.41441350595


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [5/10], Training Loss: 10.11188104082


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [6/10], Training Loss: 9.66431114077


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [7/10], Training Loss: 10.05600389195


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [8/10], Training Loss: 11.12130528727


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [9/10], Training Loss: 9.83567461512


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [10/10], Training Loss: 10.42745221947
Training Run: 13, prime: 43


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [1/10], Training Loss: 9.85705945115


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [2/10], Training Loss: 10.22754985179


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [3/10], Training Loss: 9.34356299330


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [4/10], Training Loss: 9.24041074151


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [5/10], Training Loss: 9.96704054753


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [6/10], Training Loss: 9.60822517941


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [7/10], Training Loss: 9.50357160069


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [8/10], Training Loss: 9.99787689714


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [9/10], Training Loss: 8.94846003421


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [10/10], Training Loss: 10.29816286326
Training Run: 14, prime: 47


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [1/10], Training Loss: 10.08234958472


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [2/10], Training Loss: 10.29154538337


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [3/10], Training Loss: 9.91394490151


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [4/10], Training Loss: 9.55908166092


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [5/10], Training Loss: 8.29845440201


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [6/10], Training Loss: 9.57180079773


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [7/10], Training Loss: 9.24312202288


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [8/10], Training Loss: 9.18138651928


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [9/10], Training Loss: 8.53911808398


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [10/10], Training Loss: 10.20042791465
Training Run: 15, prime: 53


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [1/10], Training Loss: 9.19959185640


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [2/10], Training Loss: 9.17379327472


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [3/10], Training Loss: 9.13840399986


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [4/10], Training Loss: 8.49443040423


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [5/10], Training Loss: 10.34432993665


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [6/10], Training Loss: 8.89942743830


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [7/10], Training Loss: 9.20948023743


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [8/10], Training Loss: 9.25538343417


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [9/10], Training Loss: 9.14042227162


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [10/10], Training Loss: 9.34975390944
Training Run: 16, prime: 59


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [1/10], Training Loss: 10.32184270158


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [2/10], Training Loss: 9.64959891892


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [3/10], Training Loss: 9.86245030805


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [4/10], Training Loss: 9.31605170183


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [5/10], Training Loss: 9.81956278033


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [6/10], Training Loss: 9.11534949851


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [7/10], Training Loss: 9.22811453293


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [8/10], Training Loss: 8.30078394188


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [9/10], Training Loss: 9.61034965706


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [10/10], Training Loss: 9.89704102065
Training Run: 17, prime: 61


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [1/10], Training Loss: 9.00938408596


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [2/10], Training Loss: 9.66381157696


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [3/10], Training Loss: 9.12528016742


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [4/10], Training Loss: 9.82426782180


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [5/10], Training Loss: 10.82788457910


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [6/10], Training Loss: 8.56040443097


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [7/10], Training Loss: 10.17029875686


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [8/10], Training Loss: 10.68874522846


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [9/10], Training Loss: 9.48236768743


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [10/10], Training Loss: 9.13702260972
Training Run: 18, prime: 67


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [1/10], Training Loss: 9.03718256331


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [2/10], Training Loss: 10.39959516657


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [3/10], Training Loss: 9.34251626606


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [4/10], Training Loss: 9.32085460487


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [5/10], Training Loss: 9.77558140343


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [6/10], Training Loss: 9.47284560571


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [7/10], Training Loss: 9.69848673951


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [8/10], Training Loss: 10.16617647493


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [9/10], Training Loss: 9.59109061308


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [10/10], Training Loss: 8.89106975841
Training Run: 19, prime: 71


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [1/10], Training Loss: 9.50617262761


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [2/10], Training Loss: 9.19063938258


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [3/10], Training Loss: 10.02722343322


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [4/10], Training Loss: 8.88659077614


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [5/10], Training Loss: 8.90506995859


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [6/10], Training Loss: 9.19969831956


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [7/10], Training Loss: 10.23022916028


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [8/10], Training Loss: 8.99675154351


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [9/10], Training Loss: 10.16877753818


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [10/10], Training Loss: 9.39332657245
Training Run: 20, prime: 73


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [1/10], Training Loss: 9.36553381339


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [2/10], Training Loss: 9.69698706064


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [3/10], Training Loss: 10.30408502985


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [4/10], Training Loss: 9.77069188409


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [5/10], Training Loss: 9.15232224827


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [6/10], Training Loss: 9.82149381569


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [7/10], Training Loss: 9.51372856923


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [8/10], Training Loss: 9.92356156879


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [9/10], Training Loss: 10.38919437917


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [10/10], Training Loss: 10.40167283282
Training Run: 21, prime: 79


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [1/10], Training Loss: 9.35401856190


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [2/10], Training Loss: 9.74621249826


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [3/10], Training Loss: 9.41709134423


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [4/10], Training Loss: 8.46328873438


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [5/10], Training Loss: 9.61086920997


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [6/10], Training Loss: 8.94767637351


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [7/10], Training Loss: 9.97983404258


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [8/10], Training Loss: 9.81316260485


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [9/10], Training Loss: 9.94821683594


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [10/10], Training Loss: 8.58959769488
Training Run: 22, prime: 83


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [1/10], Training Loss: 9.50666575302


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [2/10], Training Loss: 9.58749536584


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [3/10], Training Loss: 10.73162666003


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [4/10], Training Loss: 9.11962655884


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [5/10], Training Loss: 9.42477563082


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [6/10], Training Loss: 9.75656884045


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [7/10], Training Loss: 9.61671773417


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [8/10], Training Loss: 9.05633897323


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [9/10], Training Loss: 9.63585889239


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [10/10], Training Loss: 9.13498927813
Training Run: 23, prime: 89


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [1/10], Training Loss: 8.75432597310


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [2/10], Training Loss: 9.78974318598


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [3/10], Training Loss: 8.02939580852


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [4/10], Training Loss: 10.08933220546


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [5/10], Training Loss: 8.69347493281


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [6/10], Training Loss: 9.56382473586


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [7/10], Training Loss: 9.18500921756


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [8/10], Training Loss: 11.03904177560


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [9/10], Training Loss: 9.19083795199


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [10/10], Training Loss: 8.12631734706
Training Run: 24, prime: 97


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [1/10], Training Loss: 9.49490099764


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [2/10], Training Loss: 8.78691212321


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [3/10], Training Loss: 9.59333885229


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [4/10], Training Loss: 8.67790443300


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [5/10], Training Loss: 8.84064742763


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [6/10], Training Loss: 8.47765772782


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [7/10], Training Loss: 8.89775005929


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [8/10], Training Loss: 10.20977569279


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [9/10], Training Loss: 9.77959769409


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [10/10], Training Loss: 8.79415833496
End Training


In [62]:
# Persist the trained weights from the training cell.
# NOTE(review): this path is hard-coded and must match the "model saved at:"
# path printed by train() for the current `arguments` — update both together.
import os

weights_path = "weights/model_unet_size_64_step_2_init_1_delta_250_tstep_250.pt"
# torch.save does not create parent directories; a missing weights/ folder
# would otherwise lose the result of a long training run.
os.makedirs(os.path.dirname(weights_path), exist_ok=True)
torch.save({"state": model.state_dict()}, weights_path)