In [11]:
import sys
import json
import numpy as np
from tqdm.notebook import tqdm
from toolz.curried import pipe, curry, compose

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, utils
from torch.utils.data import Dataset, DataLoader

import chnet.ch_tools as ch_tools
import chnet.utilities as ch_utils
import chnet.ch_generator as ch_gen
from chnet.ch_loader import CahnHillDataset
from chnet.models import get_model

In [13]:
def train(key="unet", 
          ngf=32,
          tanh=True,
          mid=0.0, 
          dif=0.449, 
          dim_x=96, 
          dx=0.25, 
          dt=0.01, 
          gamma=0.2, 
          init_steps=1, 
          nstep=5, 
          n_samples_trn=1024, 
          n_datasets=10, 
          final_tstep=1001, 
          num_epochs=10, 
          learning_rate=1.0e-5,
          optimizer="sgd", 
          criterion="mae",
          scale=100,
          device="cuda", 
          save=True, 
          tag="script", 
          fldr="weights"):
    
    device = torch.device("cuda:0") if device == "cuda" else torch.device("cpu")
    print(device)

    model = get_model(key=key, ngf=ngf, tanh=tanh, nstep=nstep, device=device)
    
    delta_sim_steps=(final_tstep-init_steps)//nstep
    primes = ch_utils.get_primes(50000)[:n_datasets]
    print("no. of datasets: {}".format(len(primes)))
    fout = "{}/model_{}_size_{}_step_{}_init_{}_delta_{}_tstep_{}_tanh_{}_loss_{}_tag_{}.pt".format(fldr, key, ngf, nstep, init_steps, delta_sim_steps, num_epochs*len(primes), tanh, criterion, tag)  
    print("model saved at: {}".format(fout))

    print("Start Training")
    trn_losses = []
    criterion= get_criterion(criterion=criterion, scale=scale)
    if optimizer == "sgd":
        optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
    elif optimizer == "adam":
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    print(criterion)
    print(optimizer)
    for num, prime in enumerate(primes):
        # Loss and optimizer
        torch.cuda.empty_cache()
        x_trn, y_trn = ch_gen.data_generator(nsamples=n_samples_trn, 
                                      dim_x=dim_x, 
                                      init_steps=init_steps, 
                                      delta_sim_steps = delta_sim_steps,
                                      dx=dx, 
                                      dt=dt,
                                      m_l=mid-dif, 
                                      m_r=mid+dif,
                                      n_step=nstep,
                                      gamma=gamma, 
                                      seed=2513*prime,
                                      device=device)


        trn_dataset = CahnHillDataset(x_trn, y_trn, 
                                      transform_x=lambda x: x[:,None], 
                                      transform_y=lambda x: x[:,None])

        trn_loader = DataLoader(trn_dataset, 
                                batch_size=8, 
                                shuffle=True, 
                                num_workers=4)

        print("Training Run: {}".format(num+1))

        total_step = len(trn_loader)
        
        for epoch in range(num_epochs):  
            for i, item_trn in enumerate(tqdm(trn_loader)):
                
                model.train()
                
                if "loop" in key:
                    if "solo" in key:
                        x = item_trn['x'][:,0].to(device)
                    else:
                        x = item_trn['x'][:,0].to(device)
                    y_tru = item_trn['y'].to(device)
                else:
                    x = item_trn['x'][:,0].to(device)
                    y_tru = item_trn['y'][:,-1] .to(device) 
  
                y_prd = model(x)# Forward pass
                # means_inp = x.mean(axis=(1,2,3))
                # means_out = y_prd.mean(axis=(1,2,3))
                loss = criterion(y_tru, y_prd)
                # Backward and optimize
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                trn_losses.append(np.sqrt(loss.item()))

            print ('Epoch [{}/{}], Training Loss: {:.11f}'.format(epoch+1, num_epochs, np.mean(trn_losses[-total_step:])))
            # print ("Means, inp: {:1.3f}, out: {:1.3f}".format(x.mean(axis=(1,2,3)).data, y_prd.mean(axis=(1,2,3))).data)
            obj = {}
            obj["state"] = model.state_dict()
            obj["losses"] = trn_losses
            if save:
                torch.save(obj, fout)
    print("End Training")
    return obj

def validate(key="unet", 
              ngf=32,
              tanh=True,
              conv=True,
              mid=0.0, 
              dif=0.449, 
              dim_x=96, 
              dx=0.25, 
              dt=0.01, 
              gamma=0.2, 
              nstep=20,
              init_steps=1, 
              n_samples=1024, 
              final_tstep=5000, 
              seed=8634132, 
              device="cuda", 
              weight_file="",
              vis=True):
    
    from chnet.ssim import SSIM
    from chnet.mssim import MSSSIM
    ssim_loss = SSIM(window_size=11)
    mssim_loss = MSSSIM(window_size=11, channel=1) 
    mae_loss_npy = lambda x1, x2: np.mean(np.fabs(x1-x2))
    
    device = torch.device("cuda:0") if device == "cuda" else torch.device("cpu")
    print(device)
    
    model = get_model(key=key, ngf=ngf, tanh=tanh, nstep=nstep, conv=conv, device=device)
    model.load_state_dict(torch.load(weight_file, map_location=device)["state"])

    print("Start Validation")
    torch.cuda.empty_cache()
    x_val, y_val = ch_gen.data_generator(nsamples=n_samples, 
                                  dim_x=dim_x, 
                                  init_steps=init_steps, 
                                  delta_sim_steps = (final_tstep-init_steps)//nstep,
                                  dx=dx, 
                                  dt=dt,
                                  m_l=mid-dif, 
                                  m_r=mid+dif,
                                  n_step=nstep,
                                  gamma=gamma, 
                                  seed=seed,
                                  device=device)
    
    val_dataset = CahnHillDataset(x_val, y_val, 
                                  transform_x=lambda x: x[:,None], 
                                  transform_y=lambda x: x[:,None])

    torch.cuda.empty_cache()
    model.eval()
    
    errs = []
    for ix in tqdm(range(len(val_dataset))):

        item_v = val_dataset[ix]

        x = item_v['x'][None][:,0].double().to(device)
        y_tru = item_v['y'][None][:,-1].double().to(device) 

        if "loop" in key:
            y_prd=model(x)[:,-1]
        else:
            y_prd=model(x)
        im_x = x[0,0].detach().cpu().numpy()
        im_y1 = y_tru[0,0].detach().cpu().numpy()
        im_y2 = y_prd[0,0].detach().cpu().numpy()
        errs.append(mae_loss_npy(im_y1, im_y2))

        if vis:
            if ((ix+1) % (n_samples//5)) == 0:
                ch_utils.draw_by_side([im_x, im_y1, im_y2], 
                                      sub_titles=["inp", "sim", "cnn"], 
                                      scale=8, vmin=None, vmax=None)
                
                
                print("{:1.3f}, {:1.3f}, {:1.3f}".format(ssim_loss(y_tru, x).item(), 
                      ssim_loss(y_tru, y_tru).item(), 
                      ssim_loss(y_tru, y_prd).item()))
                
                print("{:1.3f}, {:1.3f}, {:1.3f}".format(mssim_loss(y_tru.float(), x.float()).item(), 
                      mssim_loss(y_tru.float(), y_tru.float()).item(), 
                      mssim_loss(y_tru.float(), y_prd.float()).item()))
    return errs

In [6]:
@curry
def mse_loss(y1, y2, scale=1.):
    """standard MSE definition"""
    return ((y1 - y2) ** 2).sum() / y1.data.nelement() * scale

@curry
def mae_loss(y1, y2, scale=1.):
    """standard MAE definition"""
    return (torch.abs(y1 - y2)).sum() / y1.data.nelement() * scale


@curry
def ssim_loss(y1, y2, scale=11):
    from pytorch_ssim import SSIM
    ssim = SSIM(window_size=scale, size_average=True)
    return 1. - 0.5 * ssim(y1, y2)

In [56]:
# obj = train(key="unet", 
#             ngf=32, 
#             mid=0.0, 
#             dif=0.449, 
#             dim_x=96, 
#             dx=0.25, 
#             dt=0.01, 
#             gamma=0.2, 
#             n_samples_trn=1024,  
#             nstep=2, 
#             init_steps=101, 
#             final_tstep=501, 
#             num_epochs=5, 
#             n_datasets=5, 
#             learning_rate=1e-05, 
#             device="cuda", 
#             save=True,
#             tanh=False,
#             criterion=ssim_loss(scale=13),
#             tag="jnbk_ssim13")

cuda:0
no. of datasets: 5
model saved at: weights/model_unet_size_32_step_2_init_101_delta_200_tstep_25_tanh_False_tag_jnbk_ssim13.pt
Start Training
Training Run: 1


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [1/5], Training Loss: 0.97293197353


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [2/5], Training Loss: 0.92828625812


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [3/5], Training Loss: 0.90279005541


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [4/5], Training Loss: 0.89756930118


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [5/5], Training Loss: 0.88270582208
Training Run: 2


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [1/5], Training Loss: 0.87436310842


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [2/5], Training Loss: 0.86930140594


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [3/5], Training Loss: 0.86264845730


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [4/5], Training Loss: 0.85756668626


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [5/5], Training Loss: 0.85593930894
Training Run: 3


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [1/5], Training Loss: 0.85318745240


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [2/5], Training Loss: 0.84705152770


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [3/5], Training Loss: 0.84369811609


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [4/5], Training Loss: 0.83388035231


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [5/5], Training Loss: 0.82983705843
Training Run: 4


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [1/5], Training Loss: 0.82619466723


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [2/5], Training Loss: 0.82364848610


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [3/5], Training Loss: 0.81680103447


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [4/5], Training Loss: 0.81725524462


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [5/5], Training Loss: 0.81487793386
Training Run: 5


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [1/5], Training Loss: 0.81635644252


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [2/5], Training Loss: 0.81659552626


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [3/5], Training Loss: 0.81003154484


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [4/5], Training Loss: 0.80585671305


HBox(children=(FloatProgress(value=0.0, max=128.0), HTML(value='')))


Epoch [5/5], Training Loss: 0.80485255899
End Training


In [14]:
weight_file="weights/model_unet_size_128_step_2_init_1_delta_2500_tstep_300_tanh_True_loss_mae_tag_testbed.pt"
errs = validate(key="unet", 
                ngf=64,
                tanh=True,
                conv=True,
                mid=0.0, 
                dif=0.449, 
                dim_x=96,
                dx=0.25, 
                dt=0.01, 
                gamma=0.2, 
                n_samples=1024, 
                nstep=2,
                init_steps=1, 
                final_tstep=5001, 
                seed=8634132, 
                device="cuda", 
                weight_file=weight_file)
print("MAE mean: {:1.3f}, std: {:1.3f}".format(np.mean(errs), np.std(errs)))

cuda:0


RuntimeError: Error(s) in loading state_dict for UNet:
	size mismatch for encoder1.enc1conv1.weight: copying a param with shape torch.Size([128, 1, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 1, 3, 3]).
	size mismatch for encoder1.enc1norm1.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for encoder1.enc1norm1.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for encoder1.enc1norm1.running_mean: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for encoder1.enc1norm1.running_var: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for encoder1.enc1conv2.weight: copying a param with shape torch.Size([128, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 64, 3, 3]).
	size mismatch for encoder1.enc1norm2.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for encoder1.enc1norm2.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for encoder1.enc1norm2.running_mean: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for encoder1.enc1norm2.running_var: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for encoder2.enc2conv1.weight: copying a param with shape torch.Size([256, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 64, 3, 3]).
	size mismatch for encoder2.enc2norm1.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for encoder2.enc2norm1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for encoder2.enc2norm1.running_mean: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for encoder2.enc2norm1.running_var: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for encoder2.enc2conv2.weight: copying a param with shape torch.Size([256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3]).
	size mismatch for encoder2.enc2norm2.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for encoder2.enc2norm2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for encoder2.enc2norm2.running_mean: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for encoder2.enc2norm2.running_var: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for encoder3.enc3conv1.weight: copying a param with shape torch.Size([512, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 128, 3, 3]).
	size mismatch for encoder3.enc3norm1.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for encoder3.enc3norm1.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for encoder3.enc3norm1.running_mean: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for encoder3.enc3norm1.running_var: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for encoder3.enc3conv2.weight: copying a param with shape torch.Size([512, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
	size mismatch for encoder3.enc3norm2.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for encoder3.enc3norm2.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for encoder3.enc3norm2.running_mean: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for encoder3.enc3norm2.running_var: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for encoder4.enc4conv1.weight: copying a param with shape torch.Size([1024, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([512, 256, 3, 3]).
	size mismatch for encoder4.enc4norm1.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for encoder4.enc4norm1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for encoder4.enc4norm1.running_mean: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for encoder4.enc4norm1.running_var: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for encoder4.enc4conv2.weight: copying a param with shape torch.Size([1024, 1024, 3, 3]) from checkpoint, the shape in current model is torch.Size([512, 512, 3, 3]).
	size mismatch for encoder4.enc4norm2.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for encoder4.enc4norm2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for encoder4.enc4norm2.running_mean: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for encoder4.enc4norm2.running_var: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for bottleneck.bottleneckconv1.weight: copying a param with shape torch.Size([2048, 1024, 3, 3]) from checkpoint, the shape in current model is torch.Size([1024, 512, 3, 3]).
	size mismatch for bottleneck.bottlenecknorm1.weight: copying a param with shape torch.Size([2048]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for bottleneck.bottlenecknorm1.bias: copying a param with shape torch.Size([2048]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for bottleneck.bottlenecknorm1.running_mean: copying a param with shape torch.Size([2048]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for bottleneck.bottlenecknorm1.running_var: copying a param with shape torch.Size([2048]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for bottleneck.bottleneckconv2.weight: copying a param with shape torch.Size([2048, 2048, 3, 3]) from checkpoint, the shape in current model is torch.Size([1024, 1024, 3, 3]).
	size mismatch for bottleneck.bottlenecknorm2.weight: copying a param with shape torch.Size([2048]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for bottleneck.bottlenecknorm2.bias: copying a param with shape torch.Size([2048]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for bottleneck.bottlenecknorm2.running_mean: copying a param with shape torch.Size([2048]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for bottleneck.bottlenecknorm2.running_var: copying a param with shape torch.Size([2048]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for upconv4.weight: copying a param with shape torch.Size([2048, 1024, 2, 2]) from checkpoint, the shape in current model is torch.Size([1024, 512, 2, 2]).
	size mismatch for upconv4.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder4.dec4conv1.weight: copying a param with shape torch.Size([1024, 2048, 3, 3]) from checkpoint, the shape in current model is torch.Size([512, 1024, 3, 3]).
	size mismatch for decoder4.dec4norm1.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder4.dec4norm1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder4.dec4norm1.running_mean: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder4.dec4norm1.running_var: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder4.dec4conv2.weight: copying a param with shape torch.Size([1024, 1024, 3, 3]) from checkpoint, the shape in current model is torch.Size([512, 512, 3, 3]).
	size mismatch for decoder4.dec4norm2.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder4.dec4norm2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder4.dec4norm2.running_mean: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder4.dec4norm2.running_var: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for upconv3.weight: copying a param with shape torch.Size([1024, 512, 2, 2]) from checkpoint, the shape in current model is torch.Size([512, 256, 2, 2]).
	size mismatch for upconv3.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for decoder3.dec3conv1.weight: copying a param with shape torch.Size([512, 1024, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 512, 3, 3]).
	size mismatch for decoder3.dec3norm1.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for decoder3.dec3norm1.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for decoder3.dec3norm1.running_mean: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for decoder3.dec3norm1.running_var: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for decoder3.dec3conv2.weight: copying a param with shape torch.Size([512, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
	size mismatch for decoder3.dec3norm2.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for decoder3.dec3norm2.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for decoder3.dec3norm2.running_mean: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for decoder3.dec3norm2.running_var: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for upconv2.weight: copying a param with shape torch.Size([512, 256, 2, 2]) from checkpoint, the shape in current model is torch.Size([256, 128, 2, 2]).
	size mismatch for upconv2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for decoder2.dec2conv1.weight: copying a param with shape torch.Size([256, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 256, 3, 3]).
	size mismatch for decoder2.dec2norm1.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for decoder2.dec2norm1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for decoder2.dec2norm1.running_mean: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for decoder2.dec2norm1.running_var: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for decoder2.dec2conv2.weight: copying a param with shape torch.Size([256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3]).
	size mismatch for decoder2.dec2norm2.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for decoder2.dec2norm2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for decoder2.dec2norm2.running_mean: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for decoder2.dec2norm2.running_var: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for upconv1.weight: copying a param with shape torch.Size([256, 128, 2, 2]) from checkpoint, the shape in current model is torch.Size([128, 64, 2, 2]).
	size mismatch for upconv1.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for decoder1.dec1conv1.weight: copying a param with shape torch.Size([128, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 64, 3, 3]).
	size mismatch for decoder1.dec1norm1.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for decoder1.dec1norm1.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for decoder1.dec1norm1.running_mean: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for decoder1.dec1norm1.running_var: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for decoder1.dec1conv2.weight: copying a param with shape torch.Size([128, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 64, 3, 3]).
	size mismatch for decoder1.dec1norm2.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for decoder1.dec1norm2.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for decoder1.dec1norm2.running_mean: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for decoder1.dec1norm2.running_var: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for conv.1.weight: copying a param with shape torch.Size([1, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 64, 3, 3]).