In [1]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.optim.lr_scheduler import ReduceLROnPlateau
import numpy as np
import random
import time
from tqdm import tqdm
from skimage.metrics import structural_similarity as ssim
import matplotlib.pyplot as plt
import torch.utils.data as data

import argparse
import os
import gzip
import numpy as np
##############################################################################################################
def load_mnist(root='/kaggle/mnist-dataset/data'):
    """
    Load the raw MNIST training images used to synthesise Moving MNIST clips.

    The IDX file is looked up first under ``root`` (as
    ``<root>/train-images-idx3-ubyte/train-images.idx3-ubyte``) and then at the
    fixed Kaggle input location this notebook was written against, so callers
    that pass a ``root`` which does not hold the file keep working as before.

    Args:
        root: Directory expected to contain the ``train-images-idx3-ubyte``
            folder. May be empty/None to use only the Kaggle location.

    Returns:
        Read-only ``uint8`` array of shape ``(num_images, 28, 28)``.

    Raises:
        FileNotFoundError: If the file exists at neither location.
    """
    relative = os.path.join('train-images-idx3-ubyte', 'train-images.idx3-ubyte')
    kaggle_path = os.path.join('/kaggle/input/mnist-dataset/data', relative)

    candidates = []
    if root:
        candidates.append(os.path.join(root, relative))
    candidates.append(kaggle_path)

    path = next((p for p in candidates if os.path.exists(p)), None)
    if path is None:
        raise FileNotFoundError(f"File not found: {kaggle_path}. Ensure the MNIST dataset is uploaded or available in {root}.")

    with open(path, 'rb') as f:
        # offset=16 skips the header bytes that precede the pixel data.
        mnist = np.frombuffer(f.read(), np.uint8, offset=16)
    return mnist.reshape(-1, 28, 28)


def load_fixed_set(root, is_train):
    """
    Load the pre-generated Moving MNIST test sequences.

    ``<root>/mnist_test_seq.npy`` is used when it exists; otherwise the fixed
    Kaggle input path this notebook was written against is loaded, which keeps
    callers that pass an unrelated ``root`` working.

    Args:
        root: Directory that may contain ``mnist_test_seq.npy``.
        is_train: Accepted for interface compatibility; not used.

    Returns:
        The array stored in the file with one trailing singleton axis appended.

    Raises:
        FileNotFoundError: Propagated from ``np.load`` if no file is found.
    """
    filename = 'mnist_test_seq.npy'
    path = os.path.join(root, filename) if root else ''
    if not (path and os.path.exists(path)):
        path = "/kaggle/input/mnist-dataset/data/mnist_test_seq.npy"
    dataset = np.load(path)
    # Append a channel axis so frames index as [..., H, W, 1].
    return dataset[..., np.newaxis]


class MovingMNIST(data.Dataset):
    """
    Moving MNIST dataset.

    Training samples (and test samples when ``num_objects[0] != 2``) are
    generated on the fly by bouncing random MNIST digits around a 64x64
    canvas; otherwise samples are read from the fixed sequence file.

    Each item is ``[idx, input, output]`` where ``input`` and ``output`` are
    float tensors scaled by 1/255 with shapes
    ``(n_frames_input, 1, 64, 64)`` and ``(n_frames_output, 1, 64, 64)``.
    """

    def __init__(self, root, is_train=True, n_frames_input=10, n_frames_output=10, num_objects=(2,),
                 transform=None):
        '''
        param root: dataset directory forwarded to the loaders.
        param is_train: generate sequences on the fly when True.
        param n_frames_input: number of conditioning frames per sample.
        param n_frames_output: number of target frames per sample (may be 0).
        param num_objects: a sequence of possible numbers of digits per clip.
        param transform: stored but currently not applied.
        '''
        super(MovingMNIST, self).__init__()

        self.dataset = None
        if is_train:
            self.mnist = load_mnist(root)
        else:
            if num_objects[0] != 2:
                self.mnist = load_mnist(root)
            else:
                self.dataset = load_fixed_set(root, False)
        # Generated data has no natural length; 10k samples make up one epoch.
        self.length = int(1e4) if self.dataset is None else self.dataset.shape[1]

        self.is_train = is_train
        # Private copy so instances never share (or mutate) the caller's sequence.
        self.num_objects = list(num_objects)
        self.n_frames_input = n_frames_input
        self.n_frames_output = n_frames_output
        self.n_frames_total = self.n_frames_input + self.n_frames_output
        self.transform = transform
        # For generating data
        self.image_size_ = 64
        self.digit_size_ = 28
        self.step_length_ = 0.1

    def get_random_trajectory(self, seq_length):
        '''
        Generate a random bouncing trajectory for one digit.

        Returns two int32 arrays (top rows, left columns) of length
        ``seq_length`` giving the digit's top-left corner in every frame.
        '''
        canvas_size = self.image_size_ - self.digit_size_
        # Position is tracked in the unit square and scaled at the end.
        x = random.random()
        y = random.random()
        theta = random.random() * 2 * np.pi
        v_y = np.sin(theta)
        v_x = np.cos(theta)

        start_y = np.zeros(seq_length)
        start_x = np.zeros(seq_length)
        for i in range(seq_length):
            # Take a step along velocity.
            y += v_y * self.step_length_
            x += v_x * self.step_length_

            # Bounce off edges.
            if x <= 0:
                x = 0
                v_x = -v_x
            if x >= 1.0:
                x = 1.0
                v_x = -v_x
            if y <= 0:
                y = 0
                v_y = -v_y
            if y >= 1.0:
                y = 1.0
                v_y = -v_y
            start_y[i] = y
            start_x[i] = x

        # Scale to the size of the canvas.
        start_y = (canvas_size * start_y).astype(np.int32)
        start_x = (canvas_size * start_x).astype(np.int32)
        return start_y, start_x

    def generate_moving_mnist(self, num_digits=2):
        '''
        Get random trajectories for the digits and generate a video.

        Returns a float32 array of shape
        ``(n_frames_total, image_size, image_size, 1)`` with values in 0..255.
        '''
        frames = np.zeros((self.n_frames_total, self.image_size_, self.image_size_), dtype=np.float32)
        for _ in range(num_digits):
            # Trajectory
            start_y, start_x = self.get_random_trajectory(self.n_frames_total)
            ind = random.randint(0, self.mnist.shape[0] - 1)
            digit_image = self.mnist[ind]
            for i in range(self.n_frames_total):
                top = start_y[i]
                left = start_x[i]
                bottom = top + self.digit_size_
                right = left + self.digit_size_
                # Draw digit; max() lets overlapping digits occlude cleanly.
                frames[i, top:bottom, left:right] = np.maximum(frames[i, top:bottom, left:right], digit_image)

        return frames[..., np.newaxis]

    def __getitem__(self, idx):
        length = self.n_frames_input + self.n_frames_output
        if self.is_train or self.num_objects[0] != 2:
            # Sample number of objects
            num_digits = random.choice(self.num_objects)
            # Generate data on the fly
            images = self.generate_moving_mnist(num_digits)
        else:
            images = self.dataset[:, idx, ...]

        r = 1  # patch size (4 in the PredRNN setup)
        w = self.image_size_ // r
        # Fold each r x r patch into the channel axis: (T, H, W, 1) -> (T, r*r, w, w).
        images = images.reshape((length, w, r, w, r)).transpose(0, 2, 4, 1, 3).reshape((length, r * r, w, w))

        frames_in = images[:self.n_frames_input]
        # Slicing yields an empty array when n_frames_output == 0, so the
        # scaling below works for every frame count.
        frames_out = images[self.n_frames_input:length]

        output = torch.from_numpy(frames_out / 255.0).contiguous().float()
        input_ = torch.from_numpy(frames_in / 255.0).contiguous().float()

        return [idx, input_, output]

    def __len__(self):
        return self.length

##############################################################################################################
    
from numpy import *
from numpy.linalg import *
from scipy.special import factorial
from functools import reduce
import torch
import torch.nn as nn

from functools import reduce

__all__ = ['M2K','K2M']

def _apply_axis_left_dot(x, mats):
    """
    Left-multiply every trailing axis of ``x`` by its matrix in ``mats``.

    ``x`` has one leading (batch) axis followed by ``len(mats)`` axes;
    ``mats[j]`` is contracted (over its second axis) with trailing axis ``j``.
    The result is viewed back to the original size of ``x``.
    """
    n_axes = x.dim() - 1
    assert len(mats) == n_axes
    original_size = x.size()
    # Each contraction consumes the last axis and puts the new one in front,
    # so walking the matrices from last to first cycles through all axes.
    for axis in range(n_axes - 1, -1, -1):
        x = tensordot(mats[axis], x, dim=[1, n_axes])
    # After the loop the batch axis sits last; move it back to the front.
    batch_first = [n_axes] + list(range(n_axes))
    return x.permute(batch_first).contiguous().view(original_size)

def _apply_axis_right_dot(x, mats):
    """
    Right-multiply every trailing axis of ``x`` by its matrix in ``mats``.

    ``x`` has one leading (batch) axis followed by ``len(mats)`` axes;
    trailing axis ``j`` is contracted with the first axis of ``mats[j]``.
    The result is viewed back to the original size of ``x``.
    """
    n_axes = x.dim() - 1
    assert len(mats) == n_axes
    original_size = x.size()
    # Move the batch axis to the end so each contraction eats axis 0.
    batch_last = list(range(1, n_axes + 1)) + [0]
    x = x.permute(batch_last)
    for axis in range(n_axes):
        x = tensordot(x, mats[axis], dim=[0, 0])
    return x.contiguous().view(original_size)

class _MK(nn.Module):
    """
    Shared base of the moment <-> kernel converters.

    For every entry ``l`` of ``shape`` an ``l x l`` matrix is built whose row
    ``i`` is ``(arange(l) - (l - 1) // 2) ** i / i!``. The matrix and its
    inverse are registered as buffers ``_M<j>`` and ``_invM<j>``.
    """

    def __init__(self, shape):
        super().__init__()
        self._size = torch.Size(shape)
        self._dim = len(shape)
        assert len(shape) > 0
        for axis, length in enumerate(shape):
            moment = zeros((length, length))
            for order in range(length):
                moment[order] = ((arange(length) - (length - 1) // 2) ** order) / factorial(order)
            inverse = inv(moment)
            self.register_buffer('_M' + str(axis), torch.from_numpy(moment))
            self.register_buffer('_invM' + str(axis), torch.from_numpy(inverse))

    @property
    def M(self):
        """List of the per-axis moment matrices."""
        return [self._buffers['_M' + str(j)] for j in range(self.dim())]

    @property
    def invM(self):
        """List of the per-axis inverse moment matrices."""
        return [self._buffers['_invM' + str(j)] for j in range(self.dim())]

    def size(self):
        """Kernel shape as a ``torch.Size``."""
        return self._size

    def dim(self):
        """Number of kernel axes."""
        return self._dim

    def _packdim(self, x):
        """Flatten all leading axes of ``x`` into one batch axis."""
        n_kernel_axes = self.dim()
        assert x.dim() >= n_kernel_axes
        if x.dim() == n_kernel_axes:
            # No batch axis yet: add one so the view below always has it.
            x = x[newaxis, :]
        x = x.contiguous()
        kernel_shape = list(x.size()[-n_kernel_axes:])
        return x.view([-1] + kernel_shape)

    def forward(self):
        pass

class M2K(_MK):
    """
    Convert a moment matrix into a convolution kernel.

    Arguments:
        shape (tuple of int): kernel shape
    Usage:
        m2k = M2K([5,5])
        m = torch.randn(5,5,dtype=torch.float64)
        k = m2k(m)
    """

    def __init__(self, shape):
        super().__init__(shape)

    def forward(self, m):
        """
        m (Tensor): torch.size=[...,*self.shape]

        Returns a tensor of the same size as ``m``.
        """
        original_size = m.size()
        packed = self._packdim(m)
        # Applying the inverse moment matrices maps moments back to kernels.
        kernel = _apply_axis_left_dot(packed, self.invM)
        return kernel.view(original_size)

class K2M(_MK):
    """
    Convert a convolution kernel into its moment matrix.

    Arguments:
        shape (tuple of int): kernel shape
    Usage:
        k2m = K2M([5,5])
        k = torch.randn(5,5,dtype=torch.float64)
        m = k2m(k)
    """

    def __init__(self, shape):
        super().__init__(shape)

    def forward(self, k):
        """
        k (Tensor): torch.size=[...,*self.shape]

        Returns a tensor of the same size as ``k``.
        """
        original_size = k.size()
        packed = self._packdim(k)
        moments = _apply_axis_left_dot(packed, self.M)
        return moments.view(original_size)


    
def tensordot(a,b,dim):
    """
    tensordot in PyTorch, see numpy.tensordot.

    Args:
        a, b: tensors to contract.
        dim: either an int ``n`` (contract the last ``n`` axes of ``a`` with
            the first ``n`` axes of ``b``; ``0`` gives the outer product), or
            a pair ``(a_axes, b_axes)`` where each entry is an int or a
            sequence of ints (negative indices allowed).

    Returns:
        Tensor whose size is the non-contracted axes of ``a`` followed by the
        non-contracted axes of ``b``.

    Raises:
        AssertionError: if the contracted extents of ``a`` and ``b`` differ.
    """
    def _prod(sizes):
        return reduce(lambda x, y: x * y, sizes, 1)

    if isinstance(dim, int):
        n_contract_a = n_contract_b = dim
        a = a.contiguous()
        b = b.contiguous()
    else:
        adims, bdims = dim[0], dim[1]
        adims = [adims] if isinstance(adims, int) else list(adims)
        bdims = [bdims] if isinstance(bdims, int) else list(bdims)
        # Normalise negative axes so the "remaining axes" computation is right.
        adims = [d % a.dim() for d in adims]
        bdims = [d % b.dim() for d in bdims]
        keep_a = [d for d in range(a.dim()) if d not in adims]
        keep_b = [d for d in range(b.dim()) if d not in bdims]
        # Contracted axes go last in a and first in b, ready for a matmul.
        a = a.permute(*(keep_a + adims)).contiguous()
        b = b.permute(*(bdims + keep_b)).contiguous()
        n_contract_a, n_contract_b = len(adims), len(bdims)

    sizea = a.size()
    sizeb = b.size()
    # Split by position from the front: slicing with [:-n] breaks for n == 0.
    split_a = len(sizea) - n_contract_a
    sizea0, sizea1 = sizea[:split_a], sizea[split_a:]
    sizeb0, sizeb1 = sizeb[:n_contract_b], sizeb[n_contract_b:]
    N = _prod(sizea1)
    assert _prod(sizeb0) == N
    c = a.view([-1, N]) @ b.view([N, -1])
    return c.view(sizea0 + sizeb1)

##############################################################################################################
import torch
import torch.nn as nn

class PhyCell_Cell(nn.Module):
    """
    Single recurrent cell with a predict/correct update.

    ``F`` (conv -> GroupNorm with 7 groups -> 1x1 conv) produces a residual
    prediction from the hidden state; ``convgate`` produces a sigmoid gate
    that blends this prediction with the input.

    Args:
        input_dim: channels of both the input and the hidden state.
        F_hidden_dim: channels inside ``F`` (GroupNorm uses 7 groups).
        kernel_size: (kh, kw) of the first convolution of ``F``.
        bias: forwarded as the ``bias`` argument of the gate convolution.
    """

    def __init__(self, input_dim, F_hidden_dim, kernel_size, bias=1):
        super().__init__()
        self.input_dim = input_dim
        self.F_hidden_dim = F_hidden_dim
        self.kernel_size = kernel_size
        self.padding = (kernel_size[0] // 2, kernel_size[1] // 2)
        self.bias = bias

        predictor = nn.Sequential()
        predictor.add_module(
            'conv1',
            nn.Conv2d(in_channels=input_dim, out_channels=F_hidden_dim,
                      kernel_size=self.kernel_size, stride=(1, 1), padding=self.padding),
        )
        predictor.add_module('bn1', nn.GroupNorm(7, F_hidden_dim))
        predictor.add_module(
            'conv2',
            nn.Conv2d(in_channels=F_hidden_dim, out_channels=input_dim,
                      kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)),
        )
        self.F = predictor

        # The gate sees input and hidden state stacked along channels.
        self.convgate = nn.Conv2d(in_channels=2 * self.input_dim,
                                  out_channels=self.input_dim,
                                  kernel_size=(3, 3),
                                  padding=(1, 1),
                                  bias=self.bias)

    def forward(self, x, hidden):
        """
        x, hidden: tensors of shape [batch, input_dim, height, width].
        Returns the next hidden state with the same shape.
        """
        gate = torch.sigmoid(self.convgate(torch.cat([x, hidden], dim=1)))
        predicted = hidden + self.F(hidden)          # prediction step
        return predicted + gate * (x - predicted)    # gated correction (element-wise)

class PhyCell(nn.Module):
    """
    Stack of ``PhyCell_Cell`` layers that keeps its hidden states in ``self.H``.

    Args:
        input_shape: (height, width) of the feature maps.
        input_dim: channels of inputs and hidden states.
        F_hidden_dims: per-layer hidden width of each cell's ``F`` network.
        n_layers: number of stacked cells.
        kernel_size: kernel size handed to every cell.
        device: device on which hidden states are allocated.
    """

    def __init__(self, input_shape, input_dim, F_hidden_dims, n_layers, kernel_size, device):
        super().__init__()
        self.input_shape = input_shape
        self.input_dim = input_dim
        self.F_hidden_dims = F_hidden_dims
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.H = []
        self.device = device

        layers = []
        for layer in range(self.n_layers):
            layers.append(PhyCell_Cell(input_dim=input_dim,
                                       F_hidden_dim=self.F_hidden_dims[layer],
                                       kernel_size=self.kernel_size))
        self.cell_list = nn.ModuleList(layers)

    def forward(self, input_, first_timestep=False):
        """
        Advance every layer by one time step.

        input_: tensor whose first axis is the batch.
        first_timestep: when True the hidden states are reset to zeros first.
        Returns ``(self.H, self.H)`` - the list of hidden states, twice.
        """
        batch_size = input_.data.size()[0]
        if first_timestep:
            self.initHidden(batch_size)  # fresh state at the start of a sequence

        layer_input = input_
        for j, cell in enumerate(self.cell_list):
            # Layer 0 reads the input; deeper layers read the layer below.
            self.H[j] = cell(layer_input, self.H[j])
            layer_input = self.H[j]

        return self.H, self.H

    def initHidden(self, batch_size):
        """Reset all hidden states to zeros for the given batch size."""
        self.H = []
        for _ in range(self.n_layers):
            zeros_state = torch.zeros(batch_size, self.input_dim, self.input_shape[0], self.input_shape[1])
            self.H.append(zeros_state.to(self.device))

    def setHidden(self, H):
        """Replace the stored hidden states."""
        self.H = H

        
class ConvLSTM_Cell(nn.Module):
    """Convolutional LSTM cell that processes a single time step."""

    def __init__(self, input_shape, input_dim, hidden_dim, kernel_size, bias=1):
        """
        input_shape: (int, int)
            Height and width of input tensor as (height, width).
        input_dim: int
            Number of channels of input tensor.
        hidden_dim: int
            Number of channels of hidden state.
        kernel_size: (int, int)
            Size of the convolutional kernel.
        bias: bool
            Whether or not to add the bias.
        """
        super().__init__()

        self.height, self.width = input_shape
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.padding = (kernel_size[0] // 2, kernel_size[1] // 2)
        self.bias = bias

        # One convolution yields all four gate pre-activations at once.
        self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim,
                              out_channels=4 * self.hidden_dim,
                              kernel_size=self.kernel_size,
                              padding=self.padding,
                              bias=self.bias)

    def forward(self, x, hidden):
        """
        x: [batch, input_dim, height, width]
        hidden: tuple (h_cur, c_cur), each [batch, hidden_dim, height, width]
        Returns (h_next, c_next).
        """
        h_cur, c_cur = hidden

        stacked = torch.cat([x, h_cur], dim=1)  # concatenate along channel axis
        gates = self.conv(stacked)
        # Channel order of the chunks: input, forget, output, candidate.
        pre_i, pre_f, pre_o, pre_g = torch.split(gates, self.hidden_dim, dim=1)

        input_gate = torch.sigmoid(pre_i)
        forget_gate = torch.sigmoid(pre_f)
        output_gate = torch.sigmoid(pre_o)
        candidate = torch.tanh(pre_g)

        c_next = forget_gate * c_cur + input_gate * candidate
        h_next = output_gate * torch.tanh(c_next)
        return h_next, c_next


class ConvLSTM(nn.Module):
    """
    Stack of ``ConvLSTM_Cell`` layers holding its own states ``self.H``/``self.C``.

    Args:
        input_shape: (height, width) of the feature maps.
        input_dim: channels of the input to the first layer.
        hidden_dims: per-layer hidden channel counts.
        n_layers: number of stacked cells.
        kernel_size: kernel size handed to every cell.
        device: device on which the states are allocated.
    """

    def __init__(self, input_shape, input_dim, hidden_dims, n_layers, kernel_size,device):
        super().__init__()
        self.input_shape = input_shape
        self.input_dim = input_dim
        self.hidden_dims = hidden_dims
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.H, self.C = [], []
        self.device = device

        layers = []
        for i in range(self.n_layers):
            # Each layer consumes the hidden channels of the layer below it.
            if i == 0:
                cur_input_dim = self.input_dim
            else:
                cur_input_dim = self.hidden_dims[i - 1]
            print('layer ',i,'input dim ', cur_input_dim, ' hidden dim ', self.hidden_dims[i])
            layers.append(ConvLSTM_Cell(input_shape=self.input_shape,
                                        input_dim=cur_input_dim,
                                        hidden_dim=self.hidden_dims[i],
                                        kernel_size=self.kernel_size))
        self.cell_list = nn.ModuleList(layers)

    def forward(self, input_, first_timestep=False):
        """
        Advance every layer by one time step.

        input_: tensor whose first axis is the batch.
        first_timestep: when True the states are reset to zeros first.
        Returns ``((H, C), H)`` - the state lists and the per-layer outputs.
        """
        batch_size = input_.data.size()[0]
        if first_timestep:
            self.initHidden(batch_size)  # fresh state at the start of a sequence

        layer_input = input_
        for j, cell in enumerate(self.cell_list):
            # Layer 0 reads the input; deeper layers read the layer below.
            self.H[j], self.C[j] = cell(layer_input, (self.H[j], self.C[j]))
            layer_input = self.H[j]

        return (self.H, self.C), self.H   # (hidden, output)

    def initHidden(self, batch_size):
        """Reset hidden and cell states of every layer to zeros."""
        self.H, self.C = [], []
        for i in range(self.n_layers):
            state_shape = (batch_size, self.hidden_dims[i], self.input_shape[0], self.input_shape[1])
            self.H.append(torch.zeros(*state_shape).to(self.device))
            self.C.append(torch.zeros(*state_shape).to(self.device))

    def setHidden(self, hidden):
        """Replace the stored states with the given ``(H, C)`` pair."""
        self.H, self.C = hidden
 

class dcgan_conv(nn.Module):
    """3x3 convolution -> GroupNorm (16 groups) -> LeakyReLU(0.2)."""

    def __init__(self, nin, nout, stride):
        super().__init__()
        conv = nn.Conv2d(in_channels=nin, out_channels=nout, kernel_size=(3, 3),
                         stride=stride, padding=1)
        norm = nn.GroupNorm(16, nout)
        activation = nn.LeakyReLU(0.2, inplace=True)
        self.main = nn.Sequential(conv, norm, activation)

    def forward(self, input):
        """Apply the block to a [batch, nin, H, W] tensor."""
        return self.main(input)

class dcgan_upconv(nn.Module):
    """3x3 transposed convolution -> GroupNorm (16 groups) -> LeakyReLU(0.2)."""

    def __init__(self, nin, nout, stride):
        super().__init__()
        # With stride 2 one extra output row/column makes the size exactly double.
        output_padding = 1 if stride == 2 else 0
        upconv = nn.ConvTranspose2d(in_channels=nin, out_channels=nout, kernel_size=(3, 3),
                                    stride=stride, padding=1, output_padding=output_padding)
        norm = nn.GroupNorm(16, nout)
        activation = nn.LeakyReLU(0.2, inplace=True)
        self.main = nn.Sequential(upconv, norm, activation)

    def forward(self, input):
        """Apply the block to a [batch, nin, H, W] tensor."""
        return self.main(input)
        
class encoder_E(nn.Module):
    """
    Shared image encoder: three ``dcgan_conv`` blocks with strides 2, 1, 2.

    With the defaults a (1) x 64 x 64 image becomes a (64) x 16 x 16 feature map.
    """

    def __init__(self, nc=1, nf=32):
        super().__init__()
        # input is (nc) x 64 x 64
        self.c1 = dcgan_conv(nc, nf, stride=2)      # (nf) x 32 x 32
        self.c2 = dcgan_conv(nf, nf, stride=1)      # (nf) x 32 x 32
        self.c3 = dcgan_conv(nf, 2 * nf, stride=2)  # (2*nf) x 16 x 16

    def forward(self, input):
        """Encode a batch of images."""
        return self.c3(self.c2(self.c1(input)))

class decoder_D(nn.Module):
    """
    Shared image decoder mirroring ``encoder_E``.

    Two ``dcgan_upconv`` blocks (strides 2 and 1) are followed by a plain
    stride-2 transposed convolution with no normalisation or activation.
    """

    def __init__(self, nc=1, nf=32):
        super().__init__()
        self.upc1 = dcgan_upconv(2 * nf, nf, stride=2)  # (nf) x 32 x 32
        self.upc2 = dcgan_upconv(nf, nf, stride=1)      # (nf) x 32 x 32
        self.upc3 = nn.ConvTranspose2d(in_channels=nf, out_channels=nc, kernel_size=(3, 3),
                                       stride=2, padding=1, output_padding=1)  # (nc) x 64 x 64

    def forward(self, input):
        """Decode a batch of feature maps into raw (pre-sigmoid) images."""
        return self.upc3(self.upc2(self.upc1(input)))


class encoder_specific(nn.Module):
    """Branch-specific encoder: two stride-1 ``dcgan_conv`` blocks (resolution kept)."""

    def __init__(self, nc=64, nf=64):
        super().__init__()
        self.c1 = dcgan_conv(nc, nf, stride=1)  # (nf) x 16 x 16
        self.c2 = dcgan_conv(nf, nf, stride=1)  # (nf) x 16 x 16

    def forward(self, input):
        """Encode shared features for one recurrent branch."""
        return self.c2(self.c1(input))

class decoder_specific(nn.Module):
    """Branch-specific decoder: two stride-1 ``dcgan_upconv`` blocks (resolution kept)."""

    def __init__(self, nc=64, nf=64):
        super().__init__()
        self.upc1 = dcgan_upconv(nf, nf, stride=1)  # (nf) x 16 x 16
        self.upc2 = dcgan_upconv(nf, nc, stride=1)  # (nc) x 16 x 16

    def forward(self, input):
        """Decode one recurrent branch's output back to shared features."""
        return self.upc2(self.upc1(input))

        
class EncoderRNN(torch.nn.Module):
    """
    Two-branch recurrent predictor around a shared encoder/decoder.

    A frame is encoded by ``encoder_E``, fed through two branch-specific
    encoders into ``phycell`` and ``convcell``, and the top-layer outputs of
    both branches are decoded and summed before the shared ``decoder_D``.

    Args:
        phycell: recurrent module of the first branch.
        convcell: recurrent module of the second branch.
        device: device all sub-modules are moved to.
    """

    def __init__(self,phycell,convcell, device):
        super().__init__()
        # Creation order is kept: it fixes the parameter initialisation order.
        self.encoder_E = encoder_E().to(device)           # shared image encoder
        self.encoder_Ep = encoder_specific().to(device)   # first-branch encoder
        self.encoder_Er = encoder_specific().to(device)   # second-branch encoder
        self.decoder_Dp = decoder_specific().to(device)   # first-branch decoder
        self.decoder_Dr = decoder_specific().to(device)   # second-branch decoder
        self.decoder_D = decoder_D().to(device)           # shared image decoder
        self.phycell = phycell.to(device)
        self.convcell = convcell.to(device)

    def forward(self, input, first_timestep=False, decoding=False):
        """
        Process one frame.

        input: batch of frames accepted by ``encoder_E``.
        first_timestep: forwarded to both recurrent branches (state reset).
        decoding: when True the first branch receives ``None`` as its input.

        Returns ``(out_phys, hidden1, output_image, out_phys, out_conv)``.
        """
        features = self.encoder_E(input)

        # In decoding mode the first branch is run without an input.
        input_phys = None if decoding else self.encoder_Ep(features)
        input_conv = self.encoder_Er(features)

        hidden1, output1 = self.phycell(input_phys, first_timestep)
        hidden2, output2 = self.convcell(input_conv, first_timestep)

        # Only the top layer of each branch is decoded.
        decoded_Dp = self.decoder_Dp(output1[-1])
        decoded_Dr = self.decoder_Dr(output2[-1])

        # Partial reconstructions for visualization.
        out_phys = torch.sigmoid(self.decoder_D(decoded_Dp))
        out_conv = torch.sigmoid(self.decoder_D(decoded_Dr))

        # The full prediction decodes the sum of both branches.
        output_image = torch.sigmoid(self.decoder_D(decoded_Dp + decoded_Dr))
        return out_phys, hidden1, output_image, out_phys, out_conv

##############################################################################################################

# Prefer the GPU, but fall back to the CPU so the script still runs without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

args = {}
# Same dataset folder ("mnist-dataset") as the paths used by the loaders above.
args['root'] = '/kaggle/input/mnist-dataset/data'

# Training set: sequences generated on the fly (10 input + 10 target frames).
mm = MovingMNIST(root=args['root'], is_train=True, n_frames_input=10, n_frames_output=10, num_objects=[2])
train_loader = torch.utils.data.DataLoader(dataset=mm, batch_size=64, shuffle=True, num_workers=0)

# Test set: the fixed pre-generated sequences.
mm = MovingMNIST(root=args['root'], is_train=False, n_frames_input=10, n_frames_output=10, num_objects=[2])
test_loader = torch.utils.data.DataLoader(dataset=mm, batch_size=64, shuffle=False, num_workers=0)

# constraints[k] is a 7x7 one-hot matrix with its 1 at row k // 7, column k % 7;
# it is the target of the (currently disabled) moment regularization.
constraints = torch.zeros((49,7,7)).to(device)
ind = 0
for i in range(0,7):
    for j in range(0,7):
        constraints[ind,i,j] = 1
        ind +=1

def train_on_batch(input_tensor, target_tensor, encoder, encoder_optimizer, criterion,teacher_forcing_ratio):
    """
    Run one optimisation step on a batch and return its normalised loss.

    Args:
        input_tensor: [batch_size, input_length, channels, cols, rows] conditioning frames.
        target_tensor: [batch_size, target_length, channels, cols, rows] frames to predict.
        encoder: callable returning a 5-tuple whose third element is the predicted frame.
        encoder_optimizer: optimizer stepping the encoder's parameters.
        criterion: loss between a predicted frame and a reference frame.
        teacher_forcing_ratio: probability of feeding ground-truth frames
            (instead of predictions) during the prediction phase.

    Returns:
        The summed loss of all steps divided by ``target_length`` (a float).
    """
    encoder_optimizer.zero_grad()
    n_input_frames = input_tensor.size(1)
    n_target_frames = target_tensor.size(1)
    loss = 0

    # Conditioning phase: predict each next input frame from the current one.
    for step in range(n_input_frames - 1):
        _, _, predicted_frame, _, _ = encoder(input_tensor[:, step, :, :, :], (step == 0))
        loss += criterion(predicted_frame, input_tensor[:, step + 1, :, :, :])

    # Prediction phase starts from the last observed frame.
    decoder_input = input_tensor[:, -1, :, :, :]

    # One draw per batch decides whether the whole rollout is teacher-forced.
    use_teacher_forcing = random.random() < teacher_forcing_ratio
    for step in range(n_target_frames):
        _, _, predicted_frame, _, _ = encoder(decoder_input)
        target = target_tensor[:, step, :, :, :]
        loss += criterion(predicted_frame, target)
        decoder_input = target if use_teacher_forcing else predicted_frame

    ## Moment regularization  # encoder.phycell.cell_list[0].F.conv1.weight # size (nb_filters,in_channels,7,7)
    #k2m = K2M([7,7]).to(device)
    #for b in range(0,encoder.phycell.cell_list[0].input_dim):
    #    filters = encoder.phycell.cell_list[0].F.conv1.weight[:,b,:,:] # (nb_filters,7,7)
    #    m = k2m(filters.double())
    #    m  = m.float()
    #    loss += criterion(m, constraints) # constrains is a precomputed matrix
    loss.backward()
    encoder_optimizer.step()
    return loss.item() / n_target_frames


def trainIters(encoder, nepochs, print_every=10,eval_every=10,name=''):
    """
    Train ``encoder`` on the module-level ``train_loader``.

    Args:
        encoder: model to optimise.
        nepochs: number of passes over ``train_loader``.
        print_every: print the epoch loss every this many epochs.
        eval_every: evaluate on ``test_loader``, step the LR scheduler and
            save a checkpoint every this many epochs.
        name: suffix used in the checkpoint file name ``encoder_<name>.pth``.

    Returns:
        List with the summed training loss of every epoch.
    """
    train_losses = []

    encoder_optimizer = torch.optim.Adam(encoder.parameters(), lr=0.001)
    scheduler_enc = ReduceLROnPlateau(encoder_optimizer, mode='min', patience=2, factor=0.1, verbose=True)
    criterion = nn.MSELoss()

    for epoch in tqdm(range(0, nepochs)):
        t0 = time.time()
        # Teacher forcing decays linearly from 1 and is clamped at 0.
        teacher_forcing_ratio = np.maximum(0, 1 - epoch * 0.003)

        loss_epoch = 0
        for batch in train_loader:
            input_tensor = batch[1].to(device)
            target_tensor = batch[2].to(device)
            loss_epoch += train_on_batch(input_tensor, target_tensor, encoder,
                                         encoder_optimizer, criterion, teacher_forcing_ratio)
        train_losses.append(loss_epoch)

        epoch_number = epoch + 1
        if epoch_number % print_every == 0:
            print('epoch ',epoch,  ' loss ',loss_epoch, ' time epoch ',time.time()-t0)

        if epoch_number % eval_every == 0:
            # Only the test MSE drives the learning-rate schedule.
            mse, _mae, _ssim_score = evaluate(encoder, test_loader)
            scheduler_enc.step(mse)
            torch.save(encoder.state_dict(), 'encoder_{}.pth'.format(name))
    return train_losses

    
def evaluate(encoder,loader):
    """
    Evaluate ``encoder`` on ``loader`` and print/return MSE, MAE and SSIM.

    For each batch the input frames are fed to warm up the recurrent state,
    then ``target_length`` frames are rolled out by feeding every prediction
    back as the next input.

    Args:
        encoder: model returning a 5-tuple whose third element is the predicted frame.
        loader: iterable of ``[idx, input, target]`` batches; must be non-empty.

    Returns:
        ``(mse, mae, ssim)`` averaged over batches. MSE and MAE are averaged
        over batch, time and channel and summed over pixels; SSIM is averaged
        over every predicted frame.
    """
    total_mse, total_mae, total_ssim = 0, 0, 0
    t0 = time.time()
    with torch.no_grad():
        for out in loader:
            input_tensor = out[1].to(device)
            target_tensor = out[2].to(device)
            input_length = input_tensor.size()[1]
            target_length = target_tensor.size()[1]

            # Warm-up: outputs are discarded, only the recurrent state matters.
            for ei in range(input_length-1):
                encoder(input_tensor[:,ei,:,:,:], (ei==0))

            decoder_input = input_tensor[:,-1,:,:,:] # first decoder input = last image of input sequence
            predictions = []

            for di in range(target_length):
                _, _, output_image, _, _ = encoder(decoder_input, False, False)
                decoder_input = output_image
                predictions.append(output_image.cpu().numpy())

            target = target_tensor.cpu().numpy()
            # Stack along a new time axis -> (batch_size, target_length, channels, H, W)
            predictions = np.stack(predictions, axis=1)

            mse_batch = np.mean((predictions-target)**2 , axis=(0,1,2)).sum()
            mae_batch = np.mean(np.abs(predictions-target) ,  axis=(0,1,2)).sum()
            total_mse += mse_batch
            total_mae += mae_batch

            n_frames = target.shape[0]*target.shape[1]
            for a in range(0,target.shape[0]):
                for b in range(0,target.shape[1]):
                    total_ssim += ssim(target[a,b,0,], predictions[a,b,0,], data_range=1.0) / n_frames

    print('eval mse ', total_mse/len(loader),  ' eval mae ', total_mae/len(loader),' eval ssim ',total_ssim/len(loader), ' time= ', time.time()-t0)
    return total_mse/len(loader),  total_mae/len(loader), total_ssim/len(loader)


# The PhyCell branch is disabled here; a second ConvLSTM with the same
# configuration (convcell_0) is passed in its place as EncoderRNN's first cell.
# phycell  =  PhyCell(input_shape=(16,16), input_dim=64, F_hidden_dims=[49], n_layers=1, kernel_size=(7,7), device=device) 
convcell_0 =  ConvLSTM(input_shape=(16,16), input_dim=64, hidden_dims=[128,128,64], n_layers=3, kernel_size=(3,3), device=device)   
convcell =  ConvLSTM(input_shape=(16,16), input_dim=64, hidden_dims=[128,128,64], n_layers=3, kernel_size=(3,3), device=device)   
# EncoderRNN(phycell, convcell, device): convcell_0 fills the "phycell" slot.
encoder  = EncoderRNN(convcell_0, convcell, device)
  
def count_parameters(model):
    """
    Return the number of trainable scalar parameters of ``model``.

    Args:
        model: object exposing ``parameters()`` that yields tensors.

    Returns:
        Total ``numel()`` over all parameters with ``requires_grad`` set.
    """
    # The file's ``from numpy import *`` rebinds ``sum`` to ``numpy.sum``,
    # which is deprecated for generators; use the builtin explicitly.
    import builtins
    return builtins.sum(p.numel() for p in model.parameters() if p.requires_grad)
   
# Report model sizes. The 'phycell ' label is printed for convcell_0, i.e. the
# ConvLSTM that occupies the first-branch slot of the encoder.
print('phycell ' , count_parameters(convcell_0) )    
print('convcell ' , count_parameters(convcell) ) 
print('encoder ' , count_parameters(encoder) ) 

# Train for 250 epochs, printing every epoch and evaluating (plus saving
# 'encoder_phydnet_double.pth') every 10 epochs.
trainIters(encoder,250,print_every=1,eval_every=10,name="phydnet_double")

# Optional: restore a saved checkpoint and evaluate it instead of training.
#encoder.load_state_dict(torch.load('save/encoder_phydnet.pth'))
#encoder.eval()
#mse, mae,ssim = evaluate(encoder,test_loader) 
# Save the final weights under a separate name from the periodic checkpoint.
torch.save(encoder.state_dict(),'encoder_{}.pth'.format('phydnet_2_lstm_250'))


  return sum(p.numel() for p in model.parameters() if p.requires_grad)


layer  0 input dim  64  hidden dim  128
layer  1 input dim  128  hidden dim  128
layer  2 input dim  128  hidden dim  64
layer  0 input dim  64  hidden dim  128
layer  1 input dim  128  hidden dim  128
layer  2 input dim  128  hidden dim  64
phycell  2508032
convcell  2508032
encoder  5368961


  0%|          | 1/250 [02:36<10:47:40, 156.07s/it]

epoch  0  loss  10.765596222877496  time epoch  156.06645011901855


  1%|          | 2/250 [05:11<10:42:21, 155.41s/it]

epoch  1  loss  5.043391272425652  time epoch  154.94752144813538


  1%|          | 3/250 [07:46<10:38:59, 155.22s/it]

epoch  2  loss  4.9797787562012665  time epoch  154.99199414253235


  2%|▏         | 4/250 [10:21<10:36:03, 155.14s/it]

epoch  3  loss  5.564058706164364  time epoch  155.00773119926453


  2%|▏         | 5/250 [12:56<10:33:18, 155.10s/it]

epoch  4  loss  5.156459802389145  time epoch  155.0195882320404


  2%|▏         | 6/250 [15:31<10:30:33, 155.06s/it]

epoch  5  loss  3.814295329153539  time epoch  154.98141622543335


  3%|▎         | 7/250 [18:05<10:27:48, 155.01s/it]

epoch  6  loss  3.728359360992909  time epoch  154.92336916923523


  3%|▎         | 8/250 [20:40<10:25:08, 154.99s/it]

epoch  7  loss  3.5967000156641  time epoch  154.94834566116333


  4%|▎         | 9/250 [23:15<10:22:24, 154.96s/it]

epoch  8  loss  3.2844294771552107  time epoch  154.86916637420654
epoch  9  loss  2.9681670844554886  time epoch  154.9164481163025


  4%|▍         | 10/250 [27:26<12:17:51, 184.46s/it]

eval mse  96.10664557195773  eval mae  178.66266209304712  eval ssim  0.7133131761697616  time=  95.5896258354187


  4%|▍         | 11/250 [30:00<11:37:43, 175.16s/it]

epoch  10  loss  3.3678231000900274  time epoch  154.0631628036499


  5%|▍         | 12/250 [32:34<11:09:26, 168.77s/it]

epoch  11  loss  3.021533003449438  time epoch  154.14252614974976


  5%|▌         | 13/250 [35:08<10:49:05, 164.33s/it]

epoch  12  loss  2.829814699292182  time epoch  154.11076831817627


  6%|▌         | 14/250 [37:42<10:34:11, 161.24s/it]

epoch  13  loss  2.7279561430215837  time epoch  154.09221363067627


  6%|▌         | 15/250 [40:16<10:23:10, 159.11s/it]

epoch  14  loss  2.883279581367969  time epoch  154.1674120426178


  6%|▋         | 16/250 [42:51<10:14:46, 157.63s/it]

epoch  15  loss  2.6785205066204076  time epoch  154.21034336090088


  7%|▋         | 17/250 [45:25<10:08:15, 156.63s/it]

epoch  16  loss  2.897063842415809  time epoch  154.30762577056885


  7%|▋         | 18/250 [47:59<10:02:47, 155.89s/it]

epoch  17  loss  2.6677572593092918  time epoch  154.1686351299286


  8%|▊         | 19/250 [50:33<9:58:09, 155.37s/it] 

epoch  18  loss  2.748244440555572  time epoch  154.13706183433533
epoch  19  loss  2.8449601411819474  time epoch  154.2375693321228


  8%|▊         | 20/250 [54:43<11:44:18, 183.73s/it]

eval mse  76.83205904626543  eval mae  161.8645928255312  eval ssim  0.7760415819105716  time=  95.55468654632568


  8%|▊         | 21/250 [57:17<11:07:39, 174.93s/it]

epoch  20  loss  2.6231693416833872  time epoch  154.41851043701172


  9%|▉         | 22/250 [59:52<10:41:11, 168.73s/it]

epoch  21  loss  2.6266471117734898  time epoch  154.2682590484619


  9%|▉         | 23/250 [1:02:26<10:21:58, 164.40s/it]

epoch  22  loss  2.6733356058597564  time epoch  154.28408980369568


 10%|▉         | 24/250 [1:05:00<10:07:48, 161.37s/it]

epoch  23  loss  2.6806959584355377  time epoch  154.2967324256897


 10%|█         | 25/250 [1:07:34<9:56:57, 159.19s/it] 

epoch  24  loss  2.4937788575887687  time epoch  154.09974718093872


 10%|█         | 26/250 [1:10:09<9:48:41, 157.69s/it]

epoch  25  loss  2.439134562015533  time epoch  154.18627166748047


 11%|█         | 27/250 [1:12:43<9:42:23, 156.70s/it]

epoch  26  loss  2.424879378080368  time epoch  154.382159948349


 11%|█         | 28/250 [1:15:17<9:37:11, 156.00s/it]

epoch  27  loss  2.5841910392045966  time epoch  154.3720998764038


 12%|█▏        | 29/250 [1:17:52<9:32:49, 155.52s/it]

epoch  28  loss  2.3775351390242587  time epoch  154.3955819606781
epoch  29  loss  2.4528605833649633  time epoch  154.44626998901367


 12%|█▏        | 30/250 [1:22:02<11:14:38, 183.99s/it]

eval mse  110.46768339120658  eval mae  239.15840110049885  eval ssim  0.6951322642886013  time=  95.93215036392212


 12%|█▏        | 31/250 [1:24:36<10:39:00, 175.07s/it]

epoch  30  loss  2.602109389007091  time epoch  154.25119042396545


 13%|█▎        | 32/250 [1:27:11<10:13:20, 168.81s/it]

epoch  31  loss  2.4333908282220365  time epoch  154.20679354667664


 13%|█▎        | 33/250 [1:29:45<9:54:42, 164.43s/it] 

epoch  32  loss  2.2732908040285102  time epoch  154.2206473350525


 14%|█▎        | 34/250 [1:32:19<9:40:39, 161.29s/it]

epoch  33  loss  2.4441874206066143  time epoch  153.95794534683228


 14%|█▍        | 35/250 [1:34:53<9:30:12, 159.13s/it]

epoch  34  loss  2.523072811961173  time epoch  154.0811631679535


 14%|█▍        | 36/250 [1:37:27<9:22:10, 157.62s/it]

epoch  35  loss  2.343582481145858  time epoch  154.0991780757904


 15%|█▍        | 37/250 [1:40:01<9:15:55, 156.60s/it]

epoch  36  loss  2.301986705511809  time epoch  154.21272706985474


 15%|█▌        | 38/250 [1:42:35<9:10:33, 155.82s/it]

epoch  37  loss  2.2010162308812142  time epoch  154.0022053718567


 16%|█▌        | 39/250 [1:45:09<9:06:13, 155.32s/it]

epoch  38  loss  2.3357132486999044  time epoch  154.1591796875
epoch  39  loss  2.4175549976527684  time epoch  154.09891629219055


 16%|█▌        | 40/250 [1:49:19<10:42:11, 183.48s/it]

eval mse  60.01266603409105  eval mae  147.79510604955587  eval ssim  0.7895562362272606  time=  95.03133845329285


 16%|█▋        | 41/250 [1:51:53<10:08:24, 174.66s/it]

epoch  40  loss  2.295700508356094  time epoch  154.07976651191711


 17%|█▋        | 42/250 [1:54:27<9:44:01, 168.47s/it] 

epoch  41  loss  2.3161981567740435  time epoch  154.02196717262268


 17%|█▋        | 43/250 [1:57:01<9:26:19, 164.15s/it]

epoch  42  loss  2.2430729389190667  time epoch  154.077623128891


 18%|█▊        | 44/250 [1:59:35<9:13:17, 161.15s/it]

epoch  43  loss  2.210802635550499  time epoch  154.15702629089355


 18%|█▊        | 45/250 [2:02:09<9:03:27, 159.06s/it]

epoch  44  loss  2.2469985634088516  time epoch  154.18285059928894


 18%|█▊        | 46/250 [2:04:43<8:55:58, 157.64s/it]

epoch  45  loss  2.3616086095571522  time epoch  154.3137457370758


 19%|█▉        | 47/250 [2:07:18<8:49:45, 156.58s/it]

epoch  46  loss  2.324525779485702  time epoch  154.10176730155945


 19%|█▉        | 48/250 [2:09:52<8:44:50, 155.89s/it]

epoch  47  loss  2.4105421259999265  time epoch  154.2941210269928


 20%|█▉        | 49/250 [2:12:26<8:40:35, 155.40s/it]

epoch  48  loss  2.2854656562209126  time epoch  154.24727606773376
epoch  49  loss  2.2494936659932137  time epoch  154.26011109352112


 20%|██        | 50/250 [2:16:36<10:12:12, 183.66s/it]

eval mse  58.462623183135015  eval mae  145.69382802392266  eval ssim  0.7789204231036121  time=  95.30478096008301


 20%|██        | 51/250 [2:19:10<9:39:51, 174.83s/it] 

epoch  50  loss  2.3277563154697427  time epoch  154.2204282283783


 21%|██        | 52/250 [2:21:44<9:16:25, 168.61s/it]

epoch  51  loss  2.3192424401640888  time epoch  154.10276746749878


 21%|██        | 53/250 [2:24:18<8:59:26, 164.30s/it]

epoch  52  loss  2.0508049510419366  time epoch  154.22624921798706


 22%|██▏       | 54/250 [2:26:52<8:46:41, 161.23s/it]

epoch  53  loss  2.0693713046610362  time epoch  154.0824191570282


 22%|██▏       | 55/250 [2:29:27<8:37:09, 159.12s/it]

epoch  54  loss  2.214559970796108  time epoch  154.20056295394897


 22%|██▏       | 56/250 [2:32:01<8:29:36, 157.61s/it]

epoch  55  loss  2.3331480957567705  time epoch  154.0734932422638


 23%|██▎       | 57/250 [2:34:35<8:23:34, 156.55s/it]

epoch  56  loss  2.207229299098254  time epoch  154.08965492248535


 23%|██▎       | 58/250 [2:37:09<8:18:35, 155.81s/it]

epoch  57  loss  2.1762430869042864  time epoch  154.07520723342896


 24%|██▎       | 59/250 [2:39:43<8:14:15, 155.27s/it]

epoch  58  loss  2.1160259217023842  time epoch  153.9924054145813
epoch  59  loss  2.2828100219368928  time epoch  154.15560388565063


 24%|██▍       | 60/250 [2:43:52<9:40:50, 183.42s/it]

eval mse  55.06485976686903  eval mae  135.10643340675693  eval ssim  0.8107695595164642  time=  94.90162682533264


 24%|██▍       | 61/250 [2:46:26<9:10:06, 174.64s/it]

epoch  60  loss  2.270796018093826  time epoch  154.13292932510376


 25%|██▍       | 62/250 [2:49:00<8:47:50, 168.46s/it]

epoch  61  loss  2.0628976166248316  time epoch  154.04931116104126


 25%|██▌       | 63/250 [2:51:34<8:31:36, 164.15s/it]

epoch  62  loss  2.260648702085017  time epoch  154.1002335548401


 26%|██▌       | 64/250 [2:54:08<8:19:33, 161.15s/it]

epoch  63  loss  2.0412716321647157  time epoch  154.13766074180603


 26%|██▌       | 65/250 [2:56:42<8:10:11, 158.98s/it]

epoch  64  loss  2.3358090750873095  time epoch  153.92979788780212


 26%|██▋       | 66/250 [2:59:16<8:03:04, 157.53s/it]

epoch  65  loss  2.2434452287852746  time epoch  154.12460899353027


 27%|██▋       | 67/250 [3:01:50<7:57:10, 156.45s/it]

epoch  66  loss  2.1413798488676545  time epoch  153.9338936805725


 27%|██▋       | 68/250 [3:04:24<7:52:16, 155.70s/it]

epoch  67  loss  2.1947723232209677  time epoch  153.93936467170715


 28%|██▊       | 69/250 [3:06:58<7:48:05, 155.17s/it]

epoch  68  loss  2.096797198802232  time epoch  153.92567920684814
epoch  69  loss  2.18074422776699  time epoch  153.98608827590942


 28%|██▊       | 70/250 [3:11:07<9:09:43, 183.24s/it]

eval mse  57.43643829928842  eval mae  140.98332399015973  eval ssim  0.7941804915558367  time=  94.71760153770447


 28%|██▊       | 71/250 [3:13:41<8:40:27, 174.45s/it]

epoch  70  loss  2.3014312691986576  time epoch  153.94884705543518


 29%|██▉       | 72/250 [3:16:15<8:19:25, 168.35s/it]

epoch  71  loss  2.0802742041647435  time epoch  154.0964126586914


 29%|██▉       | 73/250 [3:18:49<8:03:49, 164.01s/it]

epoch  72  loss  2.202663320302963  time epoch  153.88745832443237


 30%|██▉       | 74/250 [3:21:23<7:52:10, 160.97s/it]

epoch  73  loss  2.130009695142508  time epoch  153.8657205104828


 30%|███       | 75/250 [3:23:57<7:43:18, 158.85s/it]

epoch  74  loss  2.154022809118032  time epoch  153.90671682357788


 30%|███       | 76/250 [3:26:31<7:36:22, 157.37s/it]

epoch  75  loss  2.1695988453924655  time epoch  153.92057299613953


 31%|███       | 77/250 [3:29:04<7:30:40, 156.30s/it]

epoch  76  loss  2.161696970462799  time epoch  153.80883145332336


 31%|███       | 78/250 [3:31:38<7:26:09, 155.63s/it]

epoch  77  loss  2.1503468886017805  time epoch  154.07261681556702


 32%|███▏      | 79/250 [3:34:12<7:22:08, 155.14s/it]

epoch  78  loss  2.1558426231145855  time epoch  153.97698664665222
epoch  79  loss  2.289271964877845  time epoch  154.01446342468262


 32%|███▏      | 80/250 [3:38:21<8:39:09, 183.23s/it]

eval mse  62.23111746721207  eval mae  145.85428332675036  eval ssim  0.8140915399363965  time=  94.72130751609802


 32%|███▏      | 81/250 [3:40:55<8:11:14, 174.40s/it]

epoch  80  loss  2.065720968693495  time epoch  153.80037021636963


 33%|███▎      | 82/250 [3:43:29<7:51:17, 168.32s/it]

epoch  81  loss  2.044433735311031  time epoch  154.10995364189148


 33%|███▎      | 83/250 [3:46:03<7:36:38, 164.07s/it]

epoch  82  loss  2.1814938157796866  time epoch  154.14720821380615


 34%|███▎      | 84/250 [3:48:37<7:25:36, 161.06s/it]

epoch  83  loss  2.084453571587801  time epoch  154.0555272102356


 34%|███▍      | 85/250 [3:51:12<7:17:17, 159.01s/it]

epoch  84  loss  2.2213664904236787  time epoch  154.23101830482483


 34%|███▍      | 86/250 [3:53:46<7:10:38, 157.55s/it]

epoch  85  loss  2.1154447630047786  time epoch  154.1457805633545


 35%|███▍      | 87/250 [3:56:20<7:05:13, 156.53s/it]

epoch  86  loss  2.127197992801666  time epoch  154.12438797950745


 35%|███▌      | 88/250 [3:58:54<7:00:33, 155.76s/it]

epoch  87  loss  2.2006409034132957  time epoch  153.98418164253235


 36%|███▌      | 89/250 [4:01:28<6:56:34, 155.25s/it]

epoch  88  loss  2.159804455935955  time epoch  154.0423800945282
epoch  89  loss  2.2830202013254155  time epoch  154.2106466293335


 36%|███▌      | 90/250 [4:05:37<8:09:15, 183.47s/it]

eval mse  51.94958486374776  eval mae  138.63649919837903  eval ssim  0.8208392223007771  time=  95.07417297363281


 36%|███▋      | 91/250 [4:08:11<7:42:43, 174.61s/it]

epoch  90  loss  2.156074836850167  time epoch  153.93754315376282


 37%|███▋      | 92/250 [4:10:45<7:23:29, 168.41s/it]

epoch  91  loss  2.1371904097497456  time epoch  153.9443497657776


 37%|███▋      | 93/250 [4:13:19<7:09:18, 164.07s/it]

epoch  92  loss  2.1699482567608364  time epoch  153.92797875404358


 38%|███▊      | 94/250 [4:15:53<6:58:41, 161.04s/it]

epoch  93  loss  2.098791955411434  time epoch  153.96558260917664


 38%|███▊      | 95/250 [4:18:27<6:50:32, 158.92s/it]

epoch  94  loss  2.106539896130562  time epoch  153.97365880012512


 38%|███▊      | 96/250 [4:21:01<6:44:05, 157.44s/it]

epoch  95  loss  2.1054317548871038  time epoch  153.97921919822693


 39%|███▉      | 97/250 [4:23:35<6:39:07, 156.52s/it]

epoch  96  loss  2.0825126633048057  time epoch  154.36880493164062


 39%|███▉      | 98/250 [4:26:09<6:34:34, 155.76s/it]

epoch  97  loss  2.1325641356408593  time epoch  153.9782910346985


 40%|███▉      | 99/250 [4:28:43<6:30:47, 155.28s/it]

epoch  98  loss  2.156867976486682  time epoch  154.16661167144775
epoch  99  loss  2.1055562674999253  time epoch  153.9424638748169


 40%|████      | 100/250 [4:32:52<7:38:08, 183.26s/it]

eval mse  47.203501950403684  eval mae  126.1799594854853  eval ssim  0.841168515327131  time=  94.54628109931946


 40%|████      | 101/250 [4:35:26<7:13:18, 174.49s/it]

epoch  100  loss  2.058933539688587  time epoch  154.0237135887146


 41%|████      | 102/250 [4:38:00<6:55:19, 168.37s/it]

epoch  101  loss  2.1272852063179015  time epoch  154.0994837284088


 41%|████      | 103/250 [4:40:34<6:42:05, 164.12s/it]

epoch  102  loss  2.1066422753036016  time epoch  154.18771290779114


 42%|████▏     | 104/250 [4:43:08<6:32:01, 161.10s/it]

epoch  103  loss  2.0911260783672327  time epoch  154.07209587097168


 42%|████▏     | 105/250 [4:45:43<6:24:17, 159.01s/it]

epoch  104  loss  2.1338419362902625  time epoch  154.1379792690277


 42%|████▏     | 106/250 [4:48:17<6:18:12, 157.59s/it]

epoch  105  loss  2.1350570678710934  time epoch  154.2603714466095


 43%|████▎     | 107/250 [4:50:51<6:13:03, 156.53s/it]

epoch  106  loss  2.1662178754806507  time epoch  154.05842876434326


 43%|████▎     | 108/250 [4:53:25<6:09:00, 155.92s/it]

epoch  107  loss  2.0567799635231494  time epoch  154.49979758262634


 44%|████▎     | 109/250 [4:56:00<6:05:13, 155.42s/it]

epoch  108  loss  2.123118927329779  time epoch  154.24047207832336
epoch  109  loss  2.0616489619016636  time epoch  154.21527099609375


 44%|████▍     | 110/250 [5:00:09<7:08:33, 183.67s/it]

eval mse  45.005133003186266  eval mae  123.49883231387776  eval ssim  0.8541310178151542  time=  95.3281569480896


 44%|████▍     | 111/250 [5:02:43<6:44:56, 174.80s/it]

epoch  110  loss  2.0689434200525283  time epoch  154.09276008605957


 45%|████▍     | 112/250 [5:05:17<6:27:47, 168.60s/it]

epoch  111  loss  2.0711900465190407  time epoch  154.15204906463623


 45%|████▌     | 113/250 [5:07:52<6:15:06, 164.28s/it]

epoch  112  loss  2.099899508804082  time epoch  154.19250750541687


 46%|████▌     | 114/250 [5:10:26<6:05:29, 161.25s/it]

epoch  113  loss  2.0760511219501483  time epoch  154.16448640823364


 46%|████▌     | 115/250 [5:13:00<5:57:58, 159.10s/it]

epoch  114  loss  2.0762107864022257  time epoch  154.08713293075562


 46%|████▋     | 116/250 [5:15:34<5:52:07, 157.67s/it]

epoch  115  loss  2.1311343252658848  time epoch  154.3263533115387


 47%|████▋     | 117/250 [5:18:08<5:47:10, 156.62s/it]

epoch  116  loss  2.0818642579019073  time epoch  154.1817238330841


 47%|████▋     | 118/250 [5:20:43<5:42:54, 155.86s/it]

epoch  117  loss  2.1706783160567285  time epoch  154.09266662597656


 48%|████▊     | 119/250 [5:23:17<5:39:11, 155.36s/it]

epoch  118  loss  2.134655776619911  time epoch  154.1762797832489
epoch  119  loss  2.1603552058339113  time epoch  154.08913922309875


 48%|████▊     | 120/250 [5:27:26<6:37:45, 183.58s/it]

eval mse  42.36255521835036  eval mae  114.07319752881482  eval ssim  0.8704364449142262  time=  95.28582954406738


 48%|████▊     | 121/250 [5:30:00<6:15:40, 174.73s/it]

epoch  120  loss  2.101302557438612  time epoch  154.0915515422821


 49%|████▉     | 122/250 [5:32:34<5:59:30, 168.52s/it]

epoch  121  loss  2.1089185394346712  time epoch  154.0326702594757


 49%|████▉     | 123/250 [5:35:08<5:47:27, 164.15s/it]

epoch  122  loss  2.0905405774712555  time epoch  153.9610002040863


 50%|████▉     | 124/250 [5:37:42<5:38:23, 161.14s/it]

epoch  123  loss  2.2026514723896975  time epoch  154.1090865135193


 50%|█████     | 125/250 [5:40:17<5:31:23, 159.07s/it]

epoch  124  loss  2.0801022566854943  time epoch  154.23720622062683


 50%|█████     | 126/250 [5:42:51<5:25:37, 157.56s/it]

epoch  125  loss  2.1514226995408547  time epoch  154.02758932113647


 51%|█████     | 127/250 [5:45:25<5:20:48, 156.50s/it]

epoch  126  loss  2.1502346791326996  time epoch  154.01581931114197


 51%|█████     | 128/250 [5:47:59<5:16:40, 155.75s/it]

epoch  127  loss  2.138697721809149  time epoch  153.99182224273682


 52%|█████▏    | 129/250 [5:50:33<5:13:02, 155.23s/it]

epoch  128  loss  2.0654341407120227  time epoch  154.01956582069397
epoch  129  loss  2.1716041795909398  time epoch  153.98828053474426


 52%|█████▏    | 130/250 [5:54:42<6:06:43, 183.36s/it]

eval mse  50.446433753724314  eval mae  120.2916550848894  eval ssim  0.8647249656657549  time=  94.96008825302124


 52%|█████▏    | 131/250 [5:57:16<5:46:10, 174.54s/it]

epoch  130  loss  2.0931872844696047  time epoch  153.9694857597351


 53%|█████▎    | 132/250 [5:59:50<5:31:16, 168.44s/it]

epoch  131  loss  2.163465585559608  time epoch  154.2078628540039


 53%|█████▎    | 133/250 [6:02:24<5:19:59, 164.09s/it]

epoch  132  loss  2.087843815237284  time epoch  153.94740104675293


 54%|█████▎    | 134/250 [6:04:58<5:11:17, 161.02s/it]

epoch  133  loss  2.1091270700097082  time epoch  153.83219003677368


 54%|█████▍    | 135/250 [6:07:32<5:04:35, 158.92s/it]

epoch  134  loss  2.070318482816219  time epoch  154.02658486366272


 54%|█████▍    | 136/250 [6:10:06<4:59:09, 157.45s/it]

epoch  135  loss  2.082723536342383  time epoch  154.01657342910767


 55%|█████▍    | 137/250 [6:12:40<4:54:46, 156.52s/it]

epoch  136  loss  2.145199445635081  time epoch  154.34778833389282


 55%|█████▌    | 138/250 [6:15:14<4:50:53, 155.83s/it]

epoch  137  loss  2.179213987290858  time epoch  154.2249574661255


 56%|█████▌    | 139/250 [6:17:49<4:47:32, 155.43s/it]

epoch  138  loss  2.107475806027651  time epoch  154.49677729606628
epoch  139  loss  2.066411826014519  time epoch  154.53657865524292


 56%|█████▌    | 140/250 [6:22:00<5:37:27, 184.07s/it]

eval mse  38.61034381161829  eval mae  102.35342207987597  eval ssim  0.8920632847844076  time=  96.3032579421997


 56%|█████▋    | 141/250 [6:24:34<5:18:09, 175.13s/it]

epoch  140  loss  2.1596590526401993  time epoch  154.27261924743652


 57%|█████▋    | 142/250 [6:27:08<5:03:53, 168.82s/it]

epoch  141  loss  2.111867589503527  time epoch  154.1061909198761


 57%|█████▋    | 143/250 [6:29:42<4:53:13, 164.42s/it]

epoch  142  loss  2.1276450604200363  time epoch  154.15757513046265


 58%|█████▊    | 144/250 [6:32:16<4:45:02, 161.34s/it]

epoch  143  loss  2.2007657617330567  time epoch  154.1537425518036


 58%|█████▊    | 145/250 [6:34:50<4:38:29, 159.14s/it]

epoch  144  loss  2.102570689469577  time epoch  153.9881443977356


 58%|█████▊    | 146/250 [6:37:24<4:33:15, 157.65s/it]

epoch  145  loss  2.141336983442307  time epoch  154.1750922203064


 59%|█████▉    | 147/250 [6:39:58<4:28:45, 156.56s/it]

epoch  146  loss  2.040137992799283  time epoch  154.02453780174255


 59%|█████▉    | 148/250 [6:42:33<4:24:51, 155.80s/it]

epoch  147  loss  2.1735703065991396  time epoch  154.03092622756958


 60%|█████▉    | 149/250 [6:45:07<4:21:29, 155.34s/it]

epoch  148  loss  2.1003769844770424  time epoch  154.27218008041382
epoch  149  loss  2.133081346005201  time epoch  154.7132625579834


 60%|██████    | 150/250 [6:49:17<5:06:26, 183.87s/it]

eval mse  39.39964982050999  eval mae  102.36592592858965  eval ssim  0.8896546135497416  time=  95.6609115600586


 60%|██████    | 151/250 [6:51:52<4:48:56, 175.11s/it]

epoch  150  loss  2.0853952147066592  time epoch  154.68609142303467


 61%|██████    | 152/250 [6:54:26<4:35:57, 168.96s/it]

epoch  151  loss  2.0898776538670054  time epoch  154.59431624412537


 61%|██████    | 153/250 [6:57:01<4:26:09, 164.64s/it]

epoch  152  loss  2.1406865142285834  time epoch  154.552583694458


 62%|██████▏   | 154/250 [6:59:36<4:18:31, 161.58s/it]

epoch  153  loss  2.0604467563331132  time epoch  154.45057654380798


 62%|██████▏   | 155/250 [7:02:10<4:12:27, 159.44s/it]

epoch  154  loss  2.1096007898449907  time epoch  154.452241897583


 62%|██████▏   | 156/250 [7:04:45<4:07:31, 157.99s/it]

epoch  155  loss  2.038404507189989  time epoch  154.59691429138184


 63%|██████▎   | 157/250 [7:07:19<4:03:20, 156.99s/it]

epoch  156  loss  2.1020801141858105  time epoch  154.6537094116211


 63%|██████▎   | 158/250 [7:09:54<3:59:45, 156.36s/it]

epoch  157  loss  2.1112555764615535  time epoch  154.90605545043945


 64%|██████▎   | 159/250 [7:12:29<3:56:24, 155.87s/it]

epoch  158  loss  2.141370987892151  time epoch  154.71896982192993
epoch  159  loss  2.1256612300872813  time epoch  154.50144839286804


 64%|██████▍   | 160/250 [7:16:38<4:35:48, 183.87s/it]

eval mse  41.267341322200316  eval mae  108.95504726725778  eval ssim  0.8830950686171498  time=  94.64638423919678


 64%|██████▍   | 161/250 [7:19:12<4:19:36, 175.02s/it]

epoch  160  loss  2.0918552994728086  time epoch  154.35785746574402


 65%|██████▍   | 162/250 [7:21:47<4:07:42, 168.89s/it]

epoch  161  loss  2.109695971757174  time epoch  154.58362221717834


 65%|██████▌   | 163/250 [7:24:22<3:58:39, 164.59s/it]

epoch  162  loss  2.143046727031468  time epoch  154.5562334060669


 66%|██████▌   | 164/250 [7:26:56<3:51:34, 161.57s/it]

epoch  163  loss  2.150498788803816  time epoch  154.5225441455841


 66%|██████▌   | 165/250 [7:29:31<3:45:52, 159.44s/it]

epoch  164  loss  2.1511982284486297  time epoch  154.46495723724365


 66%|██████▋   | 166/250 [7:32:05<3:41:04, 157.92s/it]

epoch  165  loss  2.1163633398711688  time epoch  154.36091113090515


 67%|██████▋   | 167/250 [7:34:39<3:37:00, 156.88s/it]

epoch  166  loss  2.141650822013617  time epoch  154.4518280029297


 67%|██████▋   | 168/250 [7:37:14<3:33:23, 156.14s/it]

epoch  167  loss  2.0685425721108914  time epoch  154.41742944717407


 68%|██████▊   | 169/250 [7:39:48<3:30:06, 155.64s/it]

epoch  168  loss  2.0943774670362463  time epoch  154.46117568016052
epoch  169  loss  2.0530920185148713  time epoch  154.41216373443604


 68%|██████▊   | 170/250 [7:43:58<4:05:02, 183.78s/it]

eval mse  41.37090296654185  eval mae  108.81232763399744  eval ssim  0.8812868842520348  time=  94.99132919311523


 68%|██████▊   | 171/250 [7:46:32<3:50:21, 174.96s/it]

epoch  170  loss  1.988031814992427  time epoch  154.36853528022766


 69%|██████▉   | 172/250 [7:49:06<3:39:23, 168.77s/it]

epoch  171  loss  1.9211711049079896  time epoch  154.31499886512756


 69%|██████▉   | 173/250 [7:51:41<3:31:01, 164.43s/it]

epoch  172  loss  1.9470125086605556  time epoch  154.32331204414368


 70%|██████▉   | 174/250 [7:54:15<3:24:35, 161.52s/it]

epoch  173  loss  1.9583395659923568  time epoch  154.71782159805298


 70%|███████   | 175/250 [7:56:50<3:19:12, 159.37s/it]

epoch  174  loss  1.9064330652356154  time epoch  154.35654664039612


 70%|███████   | 176/250 [7:59:24<3:14:41, 157.86s/it]

epoch  175  loss  1.9139512412250037  time epoch  154.32535219192505


 71%|███████   | 177/250 [8:01:58<3:10:45, 156.79s/it]

epoch  176  loss  1.9173251055181035  time epoch  154.29368925094604


 71%|███████   | 178/250 [8:04:33<3:07:16, 156.06s/it]

epoch  177  loss  1.9386085540056224  time epoch  154.3647403717041


 72%|███████▏  | 179/250 [8:07:07<3:04:03, 155.54s/it]

epoch  178  loss  1.9196492068469515  time epoch  154.32958340644836
epoch  179  loss  1.9204412668943402  time epoch  154.43599796295166


 72%|███████▏  | 180/250 [8:11:17<3:34:19, 183.71s/it]

eval mse  31.708701054761363  eval mae  87.27929323038478  eval ssim  0.9151467767113173  time=  94.92578196525574


 72%|███████▏  | 181/250 [8:13:51<3:21:08, 174.90s/it]

epoch  180  loss  1.9650989159941679  time epoch  154.36250042915344


 73%|███████▎  | 182/250 [8:16:25<3:11:14, 168.74s/it]

epoch  181  loss  1.9364979006350045  time epoch  154.3450300693512


 73%|███████▎  | 183/250 [8:19:00<3:03:36, 164.42s/it]

epoch  182  loss  1.893612316250801  time epoch  154.35056829452515


 74%|███████▎  | 184/250 [8:21:34<2:57:33, 161.42s/it]

epoch  183  loss  1.953892759233713  time epoch  154.4156653881073


 74%|███████▍  | 185/250 [8:24:08<2:52:34, 159.30s/it]

epoch  184  loss  1.9343025431036953  time epoch  154.36379837989807


 74%|███████▍  | 186/250 [8:26:43<2:48:20, 157.83s/it]

epoch  185  loss  1.9682367824018001  time epoch  154.37574005126953


 75%|███████▍  | 187/250 [8:29:17<2:44:36, 156.77s/it]

epoch  186  loss  1.9116160497069359  time epoch  154.31216955184937


 75%|███████▌  | 188/250 [8:31:52<2:41:17, 156.08s/it]

epoch  187  loss  1.8856766968965524  time epoch  154.47916531562805


 76%|███████▌  | 189/250 [8:34:26<2:38:09, 155.56s/it]

epoch  188  loss  1.921452408283948  time epoch  154.3278923034668
epoch  189  loss  1.9569699563086018  time epoch  154.28751420974731


 76%|███████▌  | 190/250 [8:38:35<3:03:36, 183.61s/it]

eval mse  31.390560235187507  eval mae  85.81720349743108  eval ssim  0.9169562330922768  time=  94.73160576820374


 76%|███████▋  | 191/250 [8:41:09<2:51:53, 174.81s/it]

epoch  190  loss  1.958449870347977  time epoch  154.2738311290741


 77%|███████▋  | 192/250 [8:43:43<2:43:01, 168.65s/it]

epoch  191  loss  1.9184302598237988  time epoch  154.28635430335999


 77%|███████▋  | 193/250 [8:46:18<2:36:08, 164.36s/it]

epoch  192  loss  1.9057676240801806  time epoch  154.32637190818787


 78%|███████▊  | 194/250 [8:48:52<2:30:37, 161.38s/it]

epoch  193  loss  1.943492625653744  time epoch  154.4401924610138


 78%|███████▊  | 195/250 [8:51:27<2:26:03, 159.33s/it]

epoch  194  loss  1.9507369369268424  time epoch  154.53117895126343


 78%|███████▊  | 196/250 [8:54:01<2:22:03, 157.83s/it]

epoch  195  loss  1.9728839278221142  time epoch  154.35034728050232


 79%|███████▉  | 197/250 [8:56:36<2:18:30, 156.81s/it]

epoch  196  loss  1.8735969595611088  time epoch  154.4013547897339


 79%|███████▉  | 198/250 [8:59:10<2:15:15, 156.07s/it]

epoch  197  loss  1.969385667890309  time epoch  154.3679096698761


 80%|███████▉  | 199/250 [9:01:44<2:12:16, 155.62s/it]

epoch  198  loss  1.9241884388029582  time epoch  154.54832863807678
epoch  199  loss  1.9989986553788195  time epoch  154.38984322547913


 80%|████████  | 200/250 [9:05:54<2:33:15, 183.92s/it]

eval mse  31.05320936555316  eval mae  86.62390724716673  eval ssim  0.9163682378539071  time=  95.51671719551086


 80%|████████  | 201/250 [9:08:28<2:22:52, 174.95s/it]

epoch  200  loss  1.9346434354782107  time epoch  154.03267216682434


 81%|████████  | 202/250 [9:11:02<2:14:55, 168.65s/it]

epoch  201  loss  1.9307224363088604  time epoch  153.94129753112793


 81%|████████  | 203/250 [9:13:36<2:08:38, 164.23s/it]

epoch  202  loss  1.9585870854556555  time epoch  153.9076042175293


 82%|████████▏ | 204/250 [9:16:10<2:03:34, 161.19s/it]

epoch  203  loss  1.928875683248042  time epoch  154.10266637802124


 82%|████████▏ | 205/250 [9:18:45<1:59:18, 159.08s/it]

epoch  204  loss  1.9538686193525787  time epoch  154.15338850021362


 82%|████████▏ | 206/250 [9:21:19<1:55:33, 157.59s/it]

epoch  205  loss  1.9637460708618164  time epoch  154.10944366455078


 83%|████████▎ | 207/250 [9:23:53<1:52:11, 156.55s/it]

epoch  206  loss  1.9455708362162119  time epoch  154.12586212158203


 83%|████████▎ | 208/250 [9:26:27<1:49:03, 155.81s/it]

epoch  207  loss  1.9506938353180878  time epoch  154.06823444366455


 84%|████████▎ | 209/250 [9:29:01<1:46:05, 155.25s/it]

epoch  208  loss  1.9678290717303748  time epoch  153.96669840812683
epoch  209  loss  1.9097732655704025  time epoch  154.0475549697876


 84%|████████▍ | 210/250 [9:33:10<2:02:18, 183.45s/it]

eval mse  30.832858553357944  eval mae  84.57614281526796  eval ssim  0.9190497875100948  time=  95.1401994228363


 84%|████████▍ | 211/250 [9:35:44<1:53:29, 174.60s/it]

epoch  210  loss  1.9616649590432633  time epoch  153.9373209476471


 85%|████████▍ | 212/250 [9:38:18<1:46:43, 168.50s/it]

epoch  211  loss  1.9596598617732517  time epoch  154.27547192573547


 85%|████████▌ | 213/250 [9:40:52<1:41:11, 164.09s/it]

epoch  212  loss  1.9383750252425678  time epoch  153.80190205574036


 86%|████████▌ | 214/250 [9:43:26<1:36:36, 161.02s/it]

epoch  213  loss  2.0079119868576516  time epoch  153.86043334007263


 86%|████████▌ | 215/250 [9:46:00<1:32:40, 158.87s/it]

epoch  214  loss  1.9372577056288722  time epoch  153.85619711875916


 86%|████████▋ | 216/250 [9:48:34<1:29:09, 157.35s/it]

epoch  215  loss  1.9700480923056607  time epoch  153.8035798072815


 87%|████████▋ | 217/250 [9:51:08<1:26:00, 156.36s/it]

epoch  216  loss  1.971555416285991  time epoch  154.05547952651978


 87%|████████▋ | 218/250 [9:53:42<1:23:00, 155.65s/it]

epoch  217  loss  1.975868988782167  time epoch  153.96936440467834


 88%|████████▊ | 219/250 [9:56:16<1:20:10, 155.18s/it]

epoch  218  loss  1.9413552924990654  time epoch  154.10547828674316
epoch  219  loss  1.9271832965314388  time epoch  154.5313844680786


 88%|████████▊ | 220/250 [10:00:27<1:31:56, 183.89s/it]

eval mse  30.964607506041315  eval mae  84.93439818947178  eval ssim  0.9185955722800405  time=  96.28387975692749


 88%|████████▊ | 221/250 [10:03:01<1:24:37, 175.09s/it]

epoch  220  loss  1.9701931625604632  time epoch  154.55300521850586


 89%|████████▉ | 222/250 [10:05:36<1:18:50, 168.94s/it]

epoch  221  loss  1.9335451334714897  time epoch  154.59240627288818


 89%|████████▉ | 223/250 [10:08:10<1:14:05, 164.65s/it]

epoch  222  loss  1.9939743235707275  time epoch  154.63725972175598


 90%|████████▉ | 224/250 [10:10:45<1:10:02, 161.62s/it]

epoch  223  loss  1.918522734194994  time epoch  154.54618310928345


 90%|█████████ | 225/250 [10:13:19<1:06:26, 159.48s/it]

epoch  224  loss  1.9529297530651095  time epoch  154.47470045089722


 90%|█████████ | 226/250 [10:15:54<1:03:11, 158.00s/it]

epoch  225  loss  1.9697623141109935  time epoch  154.54150533676147


 91%|█████████ | 227/250 [10:18:29<1:00:10, 156.97s/it]

epoch  226  loss  1.9640441536903386  time epoch  154.586993932724


 91%|█████████ | 228/250 [10:21:03<57:17, 156.27s/it]  

epoch  227  loss  1.9330191530287257  time epoch  154.61737418174744


 92%|█████████▏| 229/250 [10:23:37<54:29, 155.67s/it]

epoch  228  loss  2.030004294216633  time epoch  154.28532314300537
epoch  229  loss  1.9867443501949305  time epoch  154.38988327980042


 92%|█████████▏| 230/250 [10:27:47<1:01:18, 183.94s/it]

eval mse  30.463510452562076  eval mae  83.77523998090416  eval ssim  0.9198544173331721  time=  95.46284246444702


 92%|█████████▏| 231/250 [10:30:22<55:26, 175.09s/it]  

epoch  230  loss  2.0030082218348975  time epoch  154.41656398773193


 93%|█████████▎| 232/250 [10:32:57<50:41, 168.98s/it]

epoch  231  loss  1.9360561780631538  time epoch  154.7410967350006


 93%|█████████▎| 233/250 [10:35:31<46:39, 164.66s/it]

epoch  232  loss  2.0174572341144095  time epoch  154.58556461334229


 94%|█████████▎| 234/250 [10:38:06<43:06, 161.66s/it]

epoch  233  loss  1.9701077640056617  time epoch  154.6326868534088


 94%|█████████▍| 235/250 [10:40:40<39:52, 159.52s/it]

epoch  234  loss  1.9836052112281328  time epoch  154.5269320011139


 94%|█████████▍| 236/250 [10:43:15<36:52, 158.01s/it]

epoch  235  loss  1.9312767833471298  time epoch  154.49316215515137


 95%|█████████▍| 237/250 [10:45:49<33:59, 156.92s/it]

epoch  236  loss  1.9897718317806719  time epoch  154.3740291595459


 95%|█████████▌| 238/250 [10:48:24<31:13, 156.16s/it]

epoch  237  loss  1.9820616662502286  time epoch  154.38994789123535


 96%|█████████▌| 239/250 [10:50:58<28:32, 155.71s/it]

epoch  238  loss  2.0226260997354983  time epoch  154.64443016052246
epoch  239  loss  1.989303844422102  time epoch  154.56155848503113


 96%|█████████▌| 240/250 [10:55:09<30:42, 184.21s/it]

eval mse  30.407143173703723  eval mae  84.51846843160641  eval ssim  0.9196177899176389  time=  96.11152076721191


 96%|█████████▋| 241/250 [10:57:44<26:17, 175.33s/it]

epoch  240  loss  1.986802721768618  time epoch  154.58673667907715


 97%|█████████▋| 242/250 [11:00:18<22:32, 169.11s/it]

epoch  241  loss  1.9867949962615963  time epoch  154.59759306907654


 97%|█████████▋| 243/250 [11:02:53<19:12, 164.69s/it]

epoch  242  loss  1.990670458227395  time epoch  154.3730070590973


 98%|█████████▊| 244/250 [11:05:27<16:09, 161.63s/it]

epoch  243  loss  1.996066175401211  time epoch  154.49831128120422


 98%|█████████▊| 245/250 [11:08:02<13:17, 159.50s/it]

epoch  244  loss  2.0135582290589804  time epoch  154.526948928833


 98%|█████████▊| 246/250 [11:10:36<10:32, 158.04s/it]

epoch  245  loss  2.0403226971626274  time epoch  154.63276886940002


 99%|█████████▉| 247/250 [11:13:11<07:50, 156.99s/it]

epoch  246  loss  2.030952607095241  time epoch  154.54193472862244


 99%|█████████▉| 248/250 [11:15:45<05:12, 156.23s/it]

epoch  247  loss  1.995541444420816  time epoch  154.44989585876465


100%|█████████▉| 249/250 [11:18:20<02:35, 155.68s/it]

epoch  248  loss  1.998902521282434  time epoch  154.3811595439911
epoch  249  loss  2.0041833259165296  time epoch  154.4307563304901


100%|██████████| 250/250 [11:22:30<00:00, 163.80s/it]

eval mse  30.32381623869489  eval mae  84.27407890368418  eval ssim  0.9196957440875437  time=  95.49011659622192



