In [None]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.optim.lr_scheduler import ReduceLROnPlateau
import numpy as np
import random
import time
from tqdm import tqdm
from skimage.metrics import structural_similarity as ssim
import matplotlib.pyplot as plt
import torch.utils.data as data

import argparse
import os
import gzip
import numpy as np
##############################################################################################################
def load_mnist(root='/kaggle/mnist-dataset/data'):
    """
    Read the raw MNIST training images used to synthesise Moving MNIST clips.

    The file location is fixed to the Kaggle input directory; ``root`` only
    appears in the error message and does not influence which file is opened.

    Returns:
        A uint8 array reshaped to (-1, 28, 28).

    Raises:
        FileNotFoundError: if the image file does not exist.
    """
    path = os.path.join('', '/kaggle/input/mnist-dataset/data/train-images-idx3-ubyte/train-images.idx3-ubyte')
    if os.path.exists(path):
        with open(path, 'rb') as handle:
            raw = handle.read()
        # offset=16 skips the first 16 bytes of the file (its header).
        return np.frombuffer(raw, np.uint8, offset=16).reshape(-1, 28, 28)
    raise FileNotFoundError(f"File not found: {path}. Ensure the MNIST dataset is uploaded or available in {root}.")


def load_fixed_set(root, is_train):
    """
    Load the pre-generated Moving MNIST sequences from the Kaggle input folder.

    ``root`` and ``is_train`` are accepted but not used: the array is always
    read from the same fixed path.

    Returns:
        The loaded array with one extra trailing axis of size 1 appended.
    """
    path = os.path.join('', "/kaggle/input/mnist-dataset/data/mnist_test_seq.npy")
    return np.load(path)[..., np.newaxis]


class MovingMNIST(data.Dataset):
    """
    Moving MNIST dataset yielding ``[idx, input_frames, target_frames]``.

    Clips are generated on the fly from raw MNIST digits when ``is_train`` is
    true or when ``num_objects[0] != 2``; otherwise the pre-generated set
    returned by ``load_fixed_set`` is used.
    """

    def __init__(self, root, is_train=True, n_frames_input=10, n_frames_output=10, num_objects=None,
                 transform=None):
        '''
        param root: forwarded to the loader functions.
        param is_train: generate random clips instead of using the fixed set.
        param n_frames_input: number of frames returned as model input.
        param n_frames_output: number of frames returned as target (may be 0).
        param num_objects: a list of number of possible objects (default [2]).
        param transform: stored on the instance; not applied by __getitem__.
        '''
        super(MovingMNIST, self).__init__()
        # Avoid a shared mutable default argument.
        if num_objects is None:
            num_objects = [2]

        self.dataset = None
        if is_train or num_objects[0] != 2:
            self.mnist = load_mnist(root)
        else:
            self.dataset = load_fixed_set(root, False)
        # Generated data has a nominal length of 10000 samples; the fixed set
        # reports the size of its second axis.
        self.length = int(1e4) if self.dataset is None else self.dataset.shape[1]

        self.is_train = is_train
        self.num_objects = num_objects
        self.n_frames_input = n_frames_input
        self.n_frames_output = n_frames_output
        self.n_frames_total = self.n_frames_input + self.n_frames_output
        self.transform = transform
        # For generating data
        self.image_size_ = 64
        self.digit_size_ = 28
        self.step_length_ = 0.1

    def get_random_trajectory(self, seq_length):
        ''' Generate the top-left (y, x) pixel positions of one bouncing digit. '''
        canvas_size = self.image_size_ - self.digit_size_
        # Positions are tracked in the unit square and scaled at the end.
        x = random.random()
        y = random.random()
        theta = random.random() * 2 * np.pi
        v_y = np.sin(theta)
        v_x = np.cos(theta)

        start_y = np.zeros(seq_length)
        start_x = np.zeros(seq_length)
        for i in range(seq_length):
            # Take a step along velocity.
            y += v_y * self.step_length_
            x += v_x * self.step_length_

            # Bounce off edges: clamp to the border and flip the velocity.
            if x <= 0:
                x = 0
                v_x = -v_x
            if x >= 1.0:
                x = 1.0
                v_x = -v_x
            if y <= 0:
                y = 0
                v_y = -v_y
            if y >= 1.0:
                y = 1.0
                v_y = -v_y
            start_y[i] = y
            start_x[i] = x

        # Scale to the size of the canvas.
        start_y = (canvas_size * start_y).astype(np.int32)
        start_x = (canvas_size * start_x).astype(np.int32)
        return start_y, start_x

    def generate_moving_mnist(self, num_digits=2):
        '''
        Get random trajectories for the digits and generate a video.

        Returns a float32 array of shape
        (n_frames_total, image_size_, image_size_, 1).
        '''
        frames = np.zeros((self.n_frames_total, self.image_size_, self.image_size_), dtype=np.float32)
        for _ in range(num_digits):
            # Trajectory
            start_y, start_x = self.get_random_trajectory(self.n_frames_total)
            ind = random.randint(0, self.mnist.shape[0] - 1)
            digit_image = self.mnist[ind]
            for i in range(self.n_frames_total):
                top = start_y[i]
                left = start_x[i]
                bottom = top + self.digit_size_
                right = left + self.digit_size_
                # Draw digit; overlapping digits keep the brighter pixel.
                frames[i, top:bottom, left:right] = np.maximum(frames[i, top:bottom, left:right], digit_image)

        return frames[..., np.newaxis]

    def __getitem__(self, idx):
        """
        Return ``[idx, input, output]`` for one clip.

        ``input`` and ``output`` are float tensors scaled by 1/255 with shapes
        (n_frames_input, 1, H, W) and (n_frames_output, 1, H, W). With
        ``n_frames_output == 0`` the target is an empty tensor instead of
        raising on a list division.
        """
        length = self.n_frames_input + self.n_frames_output
        if self.dataset is None:
            # Sample number of objects
            num_digits = random.choice(self.num_objects)
            # Generate data on the fly
            images = self.generate_moving_mnist(num_digits)
        else:
            # The reshape below requires the fixed set to hold exactly
            # `length` frames per sequence.
            images = self.dataset[:, idx, ...]

        r = 1  # patch size (PredRNN uses 4)
        w = self.image_size_ // r
        images = images.reshape((length, w, r, w, r)).transpose(0, 2, 4, 1, 3).reshape((length, r * r, w, w))

        input_frames = images[:self.n_frames_input]
        # Slicing yields an empty array when no target frames are requested.
        target_frames = images[self.n_frames_input:length]

        target_frames = torch.from_numpy(target_frames / 255.0).contiguous().float()
        input_frames = torch.from_numpy(input_frames / 255.0).contiguous().float()

        return [idx, input_frames, target_frames]

    def __len__(self):
        return self.length

##############################################################################################################
    
from numpy import *
from numpy.linalg import *
from scipy.special import factorial
from functools import reduce
import torch
import torch.nn as nn

from functools import reduce

__all__ = ['M2K','K2M']

def _apply_axis_left_dot(x, mats):
    """
    Left-multiply every trailing axis of ``x`` by its matrix in ``mats``.

    ``x`` has one leading axis followed by ``len(mats)`` axes; trailing axis
    ``i`` is contracted with the second axis of ``mats[i]``. The result is
    viewed back to the original size of ``x``.
    """
    assert x.dim() == len(mats)+1
    original_size = x.size()
    n_axes = x.dim() - 1
    # Each contraction consumes the current last axis and puts the new axis
    # first, so the matrices are applied from the last one to the first.
    for mat in reversed(mats):
        x = tensordot(mat, x, dim=[1, n_axes])
    # The leading axis of the input has drifted to the end; move it back.
    x = x.permute([n_axes] + list(range(n_axes))).contiguous()
    return x.view(original_size)

def _apply_axis_right_dot(x, mats):
    """
    Right-multiply every trailing axis of ``x`` by its matrix in ``mats``.

    ``x`` has one leading axis followed by ``len(mats)`` axes; trailing axis
    ``i`` is contracted with the first axis of ``mats[i]``. The result is
    viewed back to the original size of ``x``.
    """
    assert x.dim() == len(mats)+1
    original_size = x.size()
    n_axes = x.dim() - 1
    # Rotate the leading axis to the end so the trailing axes come first.
    x = x.permute(list(range(1, n_axes + 1)) + [0])
    # Each contraction consumes the current first axis and appends the new
    # axis last, which restores the original axis order after n_axes steps.
    for mat in mats:
        x = tensordot(x, mat, dim=[0, 0])
    return x.contiguous().view(original_size)

class _MK(nn.Module):
    """
    Base module holding one moment matrix and its inverse per kernel axis.

    For an axis of length ``l``, row ``i`` of the matrix is
    ``(arange(l) - (l-1)//2)**i / i!``. The matrices are registered as the
    buffers ``_M<j>`` and ``_invM<j>`` for axis ``j``.
    """
    def __init__(self, shape):
        super(_MK, self).__init__()
        self._size = torch.Size(shape)
        self._dim = len(shape)
        assert len(shape) > 0
        for axis, length in enumerate(shape):
            # Integer grid offsets for this axis.
            offsets = arange(length) - (length - 1) // 2
            moment = zeros((length, length))
            for order in range(length):
                moment[order] = (offsets ** order) / factorial(order)
            self.register_buffer('_M' + str(axis), torch.from_numpy(moment))
            self.register_buffer('_invM' + str(axis), torch.from_numpy(inv(moment)))

    @property
    def M(self):
        """List of the per-axis moment matrices."""
        return [self._buffers['_M' + str(j)] for j in range(self.dim())]

    @property
    def invM(self):
        """List of the per-axis inverse moment matrices."""
        return [self._buffers['_invM' + str(j)] for j in range(self.dim())]

    def size(self):
        return self._size

    def dim(self):
        return self._dim

    def _packdim(self, x):
        """Flatten all leading axes of ``x`` into one, keeping the last ``dim()`` axes."""
        assert x.dim() >= self.dim()
        if x.dim() == self.dim():
            # Add a leading axis so there is always one to flatten into.
            x = x[newaxis, :]
        x = x.contiguous()
        trailing = list(x.size()[-self.dim():])
        return x.view([-1] + trailing)

    def forward(self):
        """No-op; subclasses provide the actual conversion."""

class M2K(_MK):
    """
    Map moment matrices to convolution kernels.

    Arguments:
        shape (tuple of int): kernel shape
    Usage:
        m2k = M2K([5,5])
        m = torch.randn(5,5,dtype=torch.float64)
        k = m2k(m)
    """
    def __init__(self, shape):
        super().__init__(shape)

    def forward(self, m):
        """
        Apply the inverse moment matrices along the trailing axes of ``m``.

        m (Tensor): torch.size=[...,*self.shape]
        Returns a tensor with the same size as ``m``.
        """
        original_size = m.size()
        packed = self._packdim(m)
        kernel = _apply_axis_left_dot(packed, self.invM)
        return kernel.view(original_size)

class K2M(_MK):
    """
    Map convolution kernels to moment matrices.

    Arguments:
        shape (tuple of int): kernel shape
    Usage:
        k2m = K2M([5,5])
        k = torch.randn(5,5,dtype=torch.float64)
        m = k2m(k)
    """
    def __init__(self, shape):
        super().__init__(shape)

    def forward(self, k):
        """
        Apply the moment matrices along the trailing axes of ``k``.

        k (Tensor): torch.size=[...,*self.shape]
        Returns a tensor with the same size as ``k``.
        """
        original_size = k.size()
        packed = self._packdim(k)
        moments = _apply_axis_left_dot(packed, self.M)
        return moments.view(original_size)


    
def tensordot(a,b,dim):
    """
    Tensor contraction for PyTorch tensors, modelled on numpy.tensordot.

    Arguments:
        a, b (Tensor): operands.
        dim: either an int ``n`` (contract the last ``n`` axes of ``a`` with
            the first ``n`` axes of ``b``; ``0`` gives the outer product), or
            a pair ``(a_axes, b_axes)`` where each entry is an int or a
            sequence of ints (lists or tuples, negative axes allowed).
    Returns:
        Tensor whose axes are the kept axes of ``a`` followed by the kept
        axes of ``b``.
    Raises:
        AssertionError: if the contracted axes do not have matching sizes.
    """
    prod = lambda x,y:x*y
    if isinstance(dim,int):
        n_a = n_b = dim
        a = a.contiguous()
        b = b.contiguous()
    else:
        adims, bdims = dim[0], dim[1]
        # Normalise to lists of non-negative axes so tuples and negative
        # indices are accepted.
        adims = [adims,] if isinstance(adims, int) else list(adims)
        bdims = [bdims,] if isinstance(bdims, int) else list(bdims)
        adims = [d % a.dim() for d in adims]
        bdims = [d % b.dim() for d in bdims]
        keep_a = [d for d in range(a.dim()) if d not in adims]
        keep_b = [d for d in range(b.dim()) if d not in bdims]
        # Move contracted axes to the end of a and to the front of b.
        a = a.permute(*(keep_a + adims)).contiguous()
        b = b.permute(*(bdims + keep_b)).contiguous()
        n_a, n_b = len(adims), len(bdims)

    sizea = a.size()
    sizeb = b.size()
    # Split by explicit position: slicing with ``-0`` would return an empty
    # prefix and break the zero-axis (outer product) case.
    split_a = len(sizea) - n_a
    sizea0, sizea1 = sizea[:split_a], sizea[split_a:]
    sizeb0, sizeb1 = sizeb[:n_b], sizeb[n_b:]
    N = reduce(prod, sizea1, 1)
    assert reduce(prod, sizeb0, 1) == N, "contracted axes of a and b differ in size"

    # Contract as a single matrix product over the flattened axes.
    c = a.view([-1,N]) @ b.view([N,-1])
    return c.view(sizea0+sizeb1)

##############################################################################################################
import torch
import torch.nn as nn

class PhyCell_Cell(nn.Module):
    """
    Single recurrent cell with a prediction step and a gated correction step.

    Arguments:
        input_dim: channel count of both the input and the hidden state.
        F_hidden_dim: channel count inside the prediction network ``F``
            (normalised with a 7-group GroupNorm).
        kernel_size: (kh, kw) of the first convolution in ``F``.
        bias: forwarded as the ``bias`` argument of the gate convolution.
    """
    def __init__(self, input_dim, F_hidden_dim, kernel_size, bias=1):
        super().__init__()
        self.input_dim = input_dim
        self.F_hidden_dim = F_hidden_dim
        self.kernel_size = kernel_size
        # Half-kernel padding in each direction.
        self.padding = kernel_size[0] // 2, kernel_size[1] // 2
        self.bias = bias

        # Prediction network: conv -> group norm -> 1x1 conv back to input_dim.
        predictor = nn.Sequential()
        predictor.add_module(
            'conv1',
            nn.Conv2d(in_channels=input_dim, out_channels=F_hidden_dim,
                      kernel_size=self.kernel_size, stride=(1, 1), padding=self.padding),
        )
        predictor.add_module('bn1', nn.GroupNorm(7, F_hidden_dim))
        predictor.add_module(
            'conv2',
            nn.Conv2d(in_channels=F_hidden_dim, out_channels=input_dim,
                      kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)),
        )
        self.F = predictor

        # Gate computed from the concatenated input and hidden state.
        self.convgate = nn.Conv2d(
            in_channels=self.input_dim + self.input_dim,
            out_channels=self.input_dim,
            kernel_size=(3, 3),
            padding=(1, 1),
            bias=self.bias,
        )

    def forward(self, x, hidden):  # x [batch_size, hidden_dim, height, width]
        """Return the next hidden state given input ``x`` and current ``hidden``."""
        # Gate in (0, 1), from input and hidden stacked along the channel axis.
        gate = torch.sigmoid(self.convgate(torch.cat([x, hidden], dim=1)))
        # Prediction: residual update of the hidden state.
        predicted = hidden + self.F(hidden)
        # Correction: element-wise blend of the prediction towards the input.
        return predicted + gate * (x - predicted)

class PhyCell(nn.Module):
    """
    Stack of ``PhyCell_Cell`` layers that keeps its hidden states in ``self.H``.

    Arguments:
        input_shape: (height, width) of the hidden-state maps.
        input_dim: channel count of the input and of every hidden state.
        F_hidden_dims: per-layer ``F_hidden_dim`` values.
        n_layers: number of stacked cells.
        kernel_size: kernel size passed to every cell.
        device: device on which hidden states are allocated.
    """
    def __init__(self, input_shape, input_dim, F_hidden_dims, n_layers, kernel_size, device):
        super().__init__()
        self.input_shape = input_shape
        self.input_dim = input_dim
        self.F_hidden_dims = F_hidden_dims
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.H = []
        self.device = device

        self.cell_list = nn.ModuleList([
            PhyCell_Cell(input_dim=input_dim,
                         F_hidden_dim=self.F_hidden_dims[layer],
                         kernel_size=self.kernel_size)
            for layer in range(self.n_layers)
        ])

    def forward(self, input_, first_timestep=False):  # input_ [batch_size, 1, channels, width, height]
        """
        Advance every layer by one time step.

        Hidden states are reset to zeros when ``first_timestep`` is true.
        Returns ``(self.H, self.H)``: the same list serves as hidden state
        and as output.
        """
        batch_size = input_.data.size()[0]
        if first_timestep:
            self.initHidden(batch_size)  # init Hidden at each forward start

        # The bottom layer reads the input; each higher layer reads the
        # freshly updated state of the layer below.
        layer_input = input_
        for j, cell in enumerate(self.cell_list):
            self.H[j] = cell(layer_input, self.H[j])
            layer_input = self.H[j]

        return self.H, self.H

    def initHidden(self, batch_size):
        """Replace ``self.H`` with one zero tensor per layer on ``self.device``."""
        height, width = self.input_shape[0], self.input_shape[1]
        self.H = [
            torch.zeros(batch_size, self.input_dim, height, width).to(self.device)
            for _ in range(self.n_layers)
        ]

    def setHidden(self, H):
        """Overwrite the stored hidden states with ``H``."""
        self.H = H

        
class ConvLSTM_Cell(nn.Module):
    """Convolutional LSTM cell that processes a single time step."""

    def __init__(self, input_shape, input_dim, hidden_dim, kernel_size, bias=1):
        """
        input_shape: (int, int)
            Height and width of input tensor as (height, width).
        input_dim: int
            Number of channels of input tensor.
        hidden_dim: int
            Number of channels of hidden state.
        kernel_size: (int, int)
            Size of the convolutional kernel.
        bias: bool
            Whether or not to add the bias.
        """
        super().__init__()

        self.height, self.width = input_shape
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.padding = kernel_size[0] // 2, kernel_size[1] // 2
        self.bias = bias

        # One convolution produces the pre-activations of all four gates.
        self.conv = nn.Conv2d(
            in_channels=self.input_dim + self.hidden_dim,
            out_channels=4 * self.hidden_dim,
            kernel_size=self.kernel_size,
            padding=self.padding,
            bias=self.bias,
        )

    def forward(self, x, hidden):  # x [batch, hidden_dim, width, height]
        """
        Run one LSTM step.

        ``hidden`` is the pair ``(h_cur, c_cur)``; returns ``(h_next, c_next)``.
        """
        h_cur, c_cur = hidden

        # Stack input and previous hidden state along the channel axis.
        gates = self.conv(torch.cat([x, h_cur], dim=1))
        # Channel blocks are ordered: input, forget, output, candidate.
        pre_i, pre_f, pre_o, pre_g = torch.chunk(gates, 4, dim=1)

        input_gate = torch.sigmoid(pre_i)
        forget_gate = torch.sigmoid(pre_f)
        output_gate = torch.sigmoid(pre_o)
        candidate = torch.tanh(pre_g)

        c_next = forget_gate * c_cur + input_gate * candidate
        h_next = output_gate * torch.tanh(c_next)
        return h_next, c_next


class ConvLSTM(nn.Module):
    """
    Stack of ``ConvLSTM_Cell`` layers keeping its states in ``self.H`` / ``self.C``.

    Arguments:
        input_shape: (height, width) of the state maps.
        input_dim: channel count of the input to the bottom layer.
        hidden_dims: per-layer hidden channel counts.
        n_layers: number of stacked cells.
        kernel_size: kernel size passed to every cell.
        device: device on which states are allocated.
    """
    def __init__(self, input_shape, input_dim, hidden_dims, n_layers, kernel_size,device):
        super().__init__()
        self.input_shape = input_shape
        self.input_dim = input_dim
        self.hidden_dims = hidden_dims
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.H, self.C = [],[]
        self.device = device

        cells = []
        for i in range(self.n_layers):
            # Layer 0 reads the input; higher layers read the previous hidden state.
            cur_input_dim = self.hidden_dims[i-1] if i > 0 else self.input_dim
            print('layer ',i,'input dim ', cur_input_dim, ' hidden dim ', self.hidden_dims[i])
            cells.append(
                ConvLSTM_Cell(input_shape=self.input_shape,
                              input_dim=cur_input_dim,
                              hidden_dim=self.hidden_dims[i],
                              kernel_size=self.kernel_size)
            )
        self.cell_list = nn.ModuleList(cells)

    def forward(self, input_, first_timestep=False):  # input_ [batch_size, 1, channels, width, height]
        """
        Advance every layer by one time step.

        States are reset to zeros when ``first_timestep`` is true.
        Returns ``((self.H, self.C), self.H)``, i.e. (hidden, output).
        """
        batch_size = input_.data.size()[0]
        if first_timestep:
            self.initHidden(batch_size)  # init Hidden at each forward start

        layer_input = input_
        for j, cell in enumerate(self.cell_list):
            self.H[j], self.C[j] = cell(layer_input, (self.H[j], self.C[j]))
            # The next layer consumes this layer's new hidden state.
            layer_input = self.H[j]

        return (self.H, self.C), self.H   # (hidden, output)

    def initHidden(self, batch_size):
        """Replace ``self.H`` and ``self.C`` with zero tensors on ``self.device``."""
        height, width = self.input_shape[0], self.input_shape[1]
        self.H, self.C = [], []
        for channels in (self.hidden_dims[i] for i in range(self.n_layers)):
            self.H.append(torch.zeros(batch_size, channels, height, width).to(self.device))
            self.C.append(torch.zeros(batch_size, channels, height, width).to(self.device))

    def setHidden(self, hidden):
        """Overwrite the stored states with the pair ``(H, C)``."""
        self.H, self.C = hidden
 

class dcgan_conv(nn.Module):
    """3x3 convolution followed by a 16-group GroupNorm and LeakyReLU(0.2)."""

    def __init__(self, nin, nout, stride):
        super().__init__()
        conv = nn.Conv2d(in_channels=nin, out_channels=nout, kernel_size=(3, 3), stride=stride, padding=1)
        norm = nn.GroupNorm(16, nout)
        activation = nn.LeakyReLU(0.2, inplace=True)
        self.main = nn.Sequential(conv, norm, activation)

    def forward(self, input):
        """Apply the conv/norm/activation stack to ``input``."""
        features = self.main(input)
        return features

class dcgan_upconv(nn.Module):
    """3x3 transposed convolution followed by a 16-group GroupNorm and LeakyReLU(0.2)."""

    def __init__(self, nin, nout, stride):
        super().__init__()
        # With stride 2, one extra output row/column makes the layer exactly
        # double the spatial size.
        output_padding = 1 if stride == 2 else 0
        upconv = nn.ConvTranspose2d(in_channels=nin, out_channels=nout, kernel_size=(3, 3),
                                    stride=stride, padding=1, output_padding=output_padding)
        norm = nn.GroupNorm(16, nout)
        activation = nn.LeakyReLU(0.2, inplace=True)
        self.main = nn.Sequential(upconv, norm, activation)

    def forward(self, input):
        """Apply the upconv/norm/activation stack to ``input``."""
        features = self.main(input)
        return features
        
class encoder_E(nn.Module):
    """
    General image encoder: three ``dcgan_conv`` stages, two of them strided.

    Shape comments below assume the defaults (nc=1, nf=32) and a 64x64 input.
    """
    def __init__(self, nc=1, nf=32):
        super().__init__()
        self.c1 = dcgan_conv(nc, nf, stride=2)      # (1) x 64 x 64 -> (32) x 32 x 32
        self.c2 = dcgan_conv(nf, nf, stride=1)      # (32) x 32 x 32
        self.c3 = dcgan_conv(nf, 2 * nf, stride=2)  # (64) x 16 x 16

    def forward(self, input):
        """Run the three stages in sequence and return the last feature map."""
        return self.c3(self.c2(self.c1(input)))

class decoder_D(nn.Module):
    """
    General image decoder mirroring ``encoder_E``.

    Shape comments below assume the defaults (nc=1, nf=32) and a 16x16 input.
    The last layer is a bare transposed convolution (no norm or activation).
    """
    def __init__(self, nc=1, nf=32):
        super().__init__()
        self.upc1 = dcgan_upconv(2 * nf, nf, stride=2)  # (32) x 32 x 32
        self.upc2 = dcgan_upconv(nf, nf, stride=1)      # (32) x 32 x 32
        self.upc3 = nn.ConvTranspose2d(in_channels=nf, out_channels=nc, kernel_size=(3, 3),
                                       stride=2, padding=1, output_padding=1)  # (nc) x 64 x 64

    def forward(self, input):
        """Run the three stages in sequence and return the decoded map."""
        return self.upc3(self.upc2(self.upc1(input)))


class encoder_specific(nn.Module):
    """Branch-specific encoder: two stride-1 ``dcgan_conv`` stages (spatial size unchanged)."""

    def __init__(self, nc=64, nf=64):
        super().__init__()
        self.c1 = dcgan_conv(nc, nf, stride=1)  # nc -> nf channels
        self.c2 = dcgan_conv(nf, nf, stride=1)  # nf -> nf channels

    def forward(self, input):
        """Run both stages in sequence."""
        return self.c2(self.c1(input))

class decoder_specific(nn.Module):
    """Branch-specific decoder: two stride-1 ``dcgan_upconv`` stages (spatial size unchanged)."""

    def __init__(self, nc=64, nf=64):
        super().__init__()
        self.upc1 = dcgan_upconv(nf, nf, stride=1)  # nf -> nf channels
        self.upc2 = dcgan_upconv(nf, nc, stride=1)  # nf -> nc channels

    def forward(self, input):
        """Run both stages in sequence."""
        return self.upc2(self.upc1(input))

        
class EncoderRNN(torch.nn.Module):
    """
    Two-branch recurrent predictor built around a shared encoder/decoder.

    A frame goes through the general encoder, then through two
    branch-specific encoders feeding ``phycell`` and ``convcell``. Each
    branch's top-layer output is decoded by its specific decoder, the two
    decoded maps are summed, and the general decoder plus a sigmoid produces
    the predicted frame.
    """
    def __init__(self,phycell,convcell, device):
        super().__init__()
        # Sub-modules are created in a fixed order and moved to `device`.
        self.encoder_E = encoder_E().to(device)           # general encoder
        self.encoder_Ep = encoder_specific().to(device)   # encoder feeding `phycell`
        self.encoder_Er = encoder_specific().to(device)   # encoder feeding `convcell`
        self.decoder_Dp = decoder_specific().to(device)   # decoder for the `phycell` branch
        self.decoder_Dr = decoder_specific().to(device)   # decoder for the `convcell` branch
        self.decoder_D = decoder_D().to(device)           # general decoder
        self.phycell = phycell.to(device)
        self.convcell = convcell.to(device)

    def forward(self, input, first_timestep=False, decoding=False):
        """
        Process one frame and predict the next one.

        Returns ``(out_phys, hidden1, output_image, out_phys, out_conv)``:
        ``hidden1`` is the hidden state returned by ``phycell``,
        ``output_image`` the combined prediction, and ``out_phys`` /
        ``out_conv`` the per-branch reconstructions.
        """
        features = self.encoder_E(input)

        # NOTE(review): with decoding=True the first recurrent module is
        # called with None as its input; confirm that module supports this
        # before using the flag.
        input_phys = None if decoding else self.encoder_Ep(features)
        input_conv = self.encoder_Er(features)

        hidden1, output1 = self.phycell(input_phys, first_timestep)
        _, output2 = self.convcell(input_conv, first_timestep)

        # Only the top layer of each recurrent stack is decoded.
        decoded_Dp = self.decoder_Dp(output1[-1])
        decoded_Dr = self.decoder_Dr(output2[-1])

        # Partial reconstructions for visualization.
        out_phys = torch.sigmoid(self.decoder_D(decoded_Dp))
        out_conv = torch.sigmoid(self.decoder_D(decoded_Dr))

        # The final prediction decodes the sum of both branches.
        output_image = torch.sigmoid(self.decoder_D(decoded_Dp + decoded_Dr))
        return out_phys, hidden1, output_image, out_phys, out_conv

##############################################################################################################

# Prefer the GPU, but fall back to the CPU so the script still runs on hosts
# without CUDA instead of failing on the first `.to(device)`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

args = {}
args['root'] = '/kaggle/input/mnist_dataset/data'

# Training set: clips generated on the fly (is_train=True).
mm = MovingMNIST(root=args['root'], is_train=True, n_frames_input=10, n_frames_output=10, num_objects=[2])
train_loader = torch.utils.data.DataLoader(dataset=mm, batch_size=64, shuffle=True, num_workers=0)

# Test set: with num_objects=[2] and is_train=False the fixed sequence file is used.
mm = MovingMNIST(root=args['root'], is_train=False, n_frames_input=10, n_frames_output=10, num_objects=[2])
test_loader = torch.utils.data.DataLoader(dataset=mm, batch_size=64, shuffle=False, num_workers=0)

# 49 one-hot 7x7 maps: map number i*7+j has a single 1 at position (i, j).
# Only referenced by the commented-out moment regularisation in train_on_batch.
constraints = torch.zeros((49,7,7)).to(device)
ind = 0
for i in range(0,7):
    for j in range(0,7):
        constraints[ind,i,j] = 1
        ind +=1

def train_on_batch(input_tensor, target_tensor, encoder, encoder_optimizer, criterion,teacher_forcing_ratio):
    """
    Run one optimisation step on a batch and return the normalised loss.

    Arguments:
        input_tensor: [batch_size, input_length, channels, cols, rows]
        target_tensor: same layout, with target_length frames.
        encoder: model called as ``encoder(frame, first_timestep)``; its
            third return value is the predicted next frame.
        encoder_optimizer: optimiser stepped once per call.
        criterion: loss applied to (prediction, ground truth).
        teacher_forcing_ratio: probability of feeding ground-truth frames
            (rather than predictions) during the decoding phase.
    Returns:
        The summed loss as a float, divided by the number of target frames.
    """
    encoder_optimizer.zero_grad()
    n_input_frames = input_tensor.size(1)
    n_target_frames = target_tensor.size(1)
    loss = 0

    # Encoding phase: feed each input frame and score the one-step prediction
    # against the next input frame. State is reset on the first frame.
    for step in range(n_input_frames - 1):
        _, _, predicted, _, _ = encoder(input_tensor[:, step, :, :, :], (step == 0))
        loss += criterion(predicted, input_tensor[:, step + 1, :, :, :])

    # Decoding phase starts from the last input frame.
    decoder_input = input_tensor[:, -1, :, :, :]

    # Teacher forcing is decided once for the whole sequence.
    use_teacher_forcing = bool(random.random() < teacher_forcing_ratio)
    for step in range(n_target_frames):
        _, _, predicted, _, _ = encoder(decoder_input)
        target = target_tensor[:, step, :, :, :]
        loss += criterion(predicted, target)
        decoder_input = target if use_teacher_forcing else predicted

    # Moment regularisation of the first PhyCell filters (a K2M([7,7]) of
    # conv1.weight compared against `constraints`) is currently disabled.
    loss.backward()
    encoder_optimizer.step()
    return loss.item() / n_target_frames


def trainIters(encoder, nepochs, print_every=10,eval_every=10,name=''):
    """
    Train ``encoder`` on the global ``train_loader`` for ``nepochs`` epochs.

    Every ``print_every`` epochs the summed epoch loss is printed; every
    ``eval_every`` epochs the model is evaluated on ``test_loader``, the
    learning-rate scheduler is stepped on the evaluation MSE, and the weights
    are saved to ``encoder_<name>.pth``.

    Returns:
        List with the summed training loss of each epoch.
    """
    encoder_optimizer = torch.optim.Adam(encoder.parameters(),lr=0.001)
    scheduler_enc = ReduceLROnPlateau(encoder_optimizer, mode='min', patience=2,factor=0.1,verbose=True)
    criterion = nn.MSELoss()

    train_losses = []
    for epoch in tqdm(range(0, nepochs)):
        epoch_start = time.time()
        # Linearly decay the teacher-forcing probability, floored at 0.
        teacher_forcing_ratio = np.maximum(0 , 1 - epoch * 0.003)

        loss_epoch = 0
        for batch in train_loader:
            input_tensor = batch[1].to(device)
            target_tensor = batch[2].to(device)
            loss_epoch += train_on_batch(input_tensor, target_tensor, encoder,
                                         encoder_optimizer, criterion, teacher_forcing_ratio)
        train_losses.append(loss_epoch)

        epoch_number = epoch + 1
        if epoch_number % print_every == 0:
            print('epoch ',epoch,  ' loss ',loss_epoch, ' time epoch ',time.time()-epoch_start)

        if epoch_number % eval_every == 0:
            # Only the MSE drives the scheduler; the other metrics are printed by evaluate.
            mse, _mae, _ssim = evaluate(encoder,test_loader)
            scheduler_enc.step(mse)
            torch.save(encoder.state_dict(),'encoder_{}.pth'.format(name))
    return train_losses

    
def evaluate(encoder,loader):
    """
    Evaluate ``encoder`` on ``loader`` and report MSE, MAE and SSIM.

    The input frames are fed to warm up the recurrent state, then the model
    is rolled out on its own predictions for the length of the target
    sequence. MSE and MAE are averaged over the batch, time and channel axes
    and summed over pixels; SSIM is averaged over all predicted frames. All
    three are averaged over batches.

    Returns:
        (mse, mae, ssim) averaged over ``len(loader)`` batches.
    """
    total_mse, total_mae, total_ssim = 0, 0, 0
    n_batches = len(loader)
    t0 = time.time()
    with torch.no_grad():
        for i, out in enumerate(loader, 0):
            input_tensor = out[1].to(device)
            target_tensor = out[2].to(device)
            input_length = input_tensor.size()[1]
            target_length = target_tensor.size()[1]

            # Warm-up on the observed frames; state is reset on the first one.
            for ei in range(input_length-1):
                encoder(input_tensor[:,ei,:,:,:], (ei==0))

            decoder_input = input_tensor[:,-1,:,:,:] # first decoder input= last image of input sequence
            predictions = []

            # Autoregressive roll-out: each prediction is the next input.
            for di in range(target_length):
                _, _, output_image, _, _ = encoder(decoder_input, False, False)
                decoder_input = output_image
                predictions.append(output_image.cpu().numpy())

            target = target_tensor.cpu().numpy()
            predictions = np.stack(predictions) # (10, batch_size, 1, 64, 64)
            predictions = predictions.swapaxes(0,1)  # (batch_size,10, 1, 64, 64)

            total_mse += np.mean((predictions-target)**2 , axis=(0,1,2)).sum()
            total_mae += np.mean(np.abs(predictions-target) ,  axis=(0,1,2)).sum()

            n_frames = target.shape[0]*target.shape[1]
            for a in range(0,target.shape[0]):
                for b in range(0,target.shape[1]):
                    total_ssim += ssim(target[a,b,0,], predictions[a,b,0,], data_range=1.0) / n_frames

            # The binary cross-entropy that used to be accumulated here was
            # never reported or returned and could evaluate log(0); it has
            # been dropped.

    print('eval mse ', total_mse/n_batches,  ' eval mae ', total_mae/n_batches,' eval ssim ',total_ssim/n_batches, ' time= ', time.time()-t0)
    return total_mse/n_batches,  total_mae/n_batches, total_ssim/n_batches


# Model construction. The PhyCell branch is disabled (kept below for reference)
# and a second ConvLSTM with the same configuration is passed in its place,
# so both recurrent branches of EncoderRNN are ConvLSTMs in this run.
# phycell  =  PhyCell(input_shape=(16,16), input_dim=64, F_hidden_dims=[49], n_layers=1, kernel_size=(7,7), device=device) 
convcell_0 =  ConvLSTM(input_shape=(16,16), input_dim=64, hidden_dims=[128,128,64], n_layers=3, kernel_size=(3,3), device=device)   
convcell =  ConvLSTM(input_shape=(16,16), input_dim=64, hidden_dims=[128,128,64], n_layers=3, kernel_size=(3,3), device=device)   
# EncoderRNN's first argument is its `phycell` slot; convcell_0 fills it here.
encoder  = EncoderRNN(convcell_0, convcell, device)
  
def count_parameters(model):
    """
    Return the number of trainable scalar parameters in ``model``.

    Only parameters with ``requires_grad`` set are counted.
    """
    # `from numpy import *` earlier in this file rebinds `sum` to numpy.sum,
    # which is deprecated for generator arguments; use the builtin explicitly.
    import builtins
    return builtins.sum(p.numel() for p in model.parameters() if p.requires_grad)
   
# Report trainable parameter counts. The 'phycell' label refers to convcell_0,
# the ConvLSTM occupying the phycell slot of the encoder.
print('phycell ' , count_parameters(convcell_0) )    
print('convcell ' , count_parameters(convcell) ) 
print('encoder ' , count_parameters(encoder) ) 

# Train for 200 epochs, printing the loss every epoch and evaluating (and
# checkpointing to 'encoder_phydnet_double.pth') every 10 epochs.
trainIters(encoder,200,print_every=1,eval_every=10,name="phydnet_double")

# Optional: reload a saved checkpoint and evaluate it instead of training.
#encoder.load_state_dict(torch.load('save/encoder_phydnet.pth'))
#encoder.eval()
#mse, mae,ssim = evaluate(encoder,test_loader) 
# Save the final weights under a separate file name.
torch.save(encoder.state_dict(),'encoder_{}.pth'.format('phydnet_2_lstm_100'))


  return sum(p.numel() for p in model.parameters() if p.requires_grad)


layer  0 input dim  64  hidden dim  128
layer  1 input dim  128  hidden dim  128
layer  2 input dim  128  hidden dim  64
layer  0 input dim  64  hidden dim  128
layer  1 input dim  128  hidden dim  128
layer  2 input dim  128  hidden dim  64
phycell  2508032
convcell  2508032
encoder  5368961


  0%|          | 1/200 [02:36<8:38:49, 156.43s/it]

epoch  0  loss  11.35341339707375  time epoch  156.429993391037


  1%|          | 2/200 [05:11<8:33:44, 155.68s/it]

epoch  1  loss  5.4381491541862506  time epoch  155.1541042327881


  2%|▏         | 3/200 [07:46<8:29:53, 155.30s/it]

epoch  2  loss  4.440904514491558  time epoch  154.8427734375


  2%|▏         | 4/200 [10:21<8:26:50, 155.16s/it]

epoch  3  loss  3.60855732858181  time epoch  154.93702459335327


  2%|▎         | 5/200 [12:55<8:23:27, 154.91s/it]

epoch  4  loss  4.20542240291834  time epoch  154.46920585632324


  3%|▎         | 6/200 [15:30<8:20:31, 154.80s/it]

epoch  5  loss  3.3180296167731282  time epoch  154.59750485420227


  4%|▎         | 7/200 [18:04<8:17:22, 154.63s/it]

epoch  6  loss  3.1657181948423396  time epoch  154.25915360450745


  4%|▍         | 8/200 [20:39<8:14:31, 154.54s/it]

epoch  7  loss  2.9742852345108988  time epoch  154.34828877449036


  4%|▍         | 9/200 [23:13<8:11:42, 154.46s/it]

epoch  8  loss  3.1846848696470254  time epoch  154.29049706459045
epoch  9  loss  2.6932097733020774  time epoch  154.05561566352844


  5%|▌         | 10/200 [27:24<9:43:27, 184.25s/it]

eval mse  85.16803712298156  eval mae  164.54955957497762  eval ssim  0.6910831944182351  time=  96.86088585853577


  6%|▌         | 11/200 [29:58<9:11:28, 175.07s/it]

epoch  10  loss  3.012785425782206  time epoch  154.2537431716919


  6%|▌         | 12/200 [32:33<8:49:00, 168.83s/it]

epoch  11  loss  2.7979340404272075  time epoch  154.57280492782593


  6%|▋         | 13/200 [35:07<8:32:58, 164.59s/it]

epoch  12  loss  2.828372929990291  time epoch  154.8161883354187


  7%|▋         | 14/200 [37:42<8:20:55, 161.59s/it]

epoch  13  loss  2.6637313663959503  time epoch  154.64924144744873


  8%|▊         | 15/200 [40:17<8:11:47, 159.50s/it]

epoch  14  loss  2.8799298286438004  time epoch  154.6628212928772


  8%|▊         | 16/200 [42:51<8:04:40, 158.05s/it]

epoch  15  loss  2.5241360604763026  time epoch  154.672217130661


  8%|▊         | 17/200 [45:26<7:58:57, 157.04s/it]

epoch  16  loss  2.657910476624966  time epoch  154.6861481666565


  9%|▉         | 18/200 [48:01<7:54:09, 156.32s/it]

epoch  17  loss  2.4811175420880316  time epoch  154.63875937461853


 10%|▉         | 19/200 [50:35<7:50:01, 155.81s/it]

epoch  18  loss  2.544031302630901  time epoch  154.6327784061432
epoch  19  loss  2.6756771132349977  time epoch  154.56731629371643


 10%|█         | 20/200 [54:46<9:13:06, 184.37s/it]

eval mse  101.17265295526784  eval mae  173.46529665418493  eval ssim  0.7705854064136481  time=  96.30354833602905


 10%|█         | 21/200 [57:21<8:43:03, 175.33s/it]

epoch  20  loss  2.8055418491363517  time epoch  154.24146056175232


 11%|█         | 22/200 [59:55<8:21:56, 169.19s/it]

epoch  21  loss  2.493461099267005  time epoch  154.88298845291138


 12%|█▏        | 23/200 [1:02:30<8:06:21, 164.87s/it]

epoch  22  loss  2.4687889292836194  time epoch  154.78537392616272


 12%|█▏        | 24/200 [1:05:05<7:54:41, 161.83s/it]

epoch  23  loss  2.61089637875557  time epoch  154.73355746269226


 12%|█▎        | 25/200 [1:07:40<7:45:47, 159.70s/it]

epoch  24  loss  2.331495662033557  time epoch  154.72553730010986


 13%|█▎        | 26/200 [1:10:14<7:38:42, 158.17s/it]

epoch  25  loss  2.447394195199012  time epoch  154.6129114627838


 14%|█▎        | 27/200 [1:12:49<7:33:05, 157.14s/it]

epoch  26  loss  2.5406135633587845  time epoch  154.7264142036438


 14%|█▍        | 28/200 [1:15:24<7:28:22, 156.41s/it]

epoch  27  loss  2.5628750711679458  time epoch  154.70391416549683


 14%|█▍        | 29/200 [1:17:59<7:24:22, 155.92s/it]

epoch  28  loss  2.509886175394059  time epoch  154.7838990688324
epoch  29  loss  2.2862218528985983  time epoch  154.67048954963684


 15%|█▌        | 30/200 [1:22:09<8:42:15, 184.33s/it]

eval mse  80.67155184411699  eval mae  152.82952948892193  eval ssim  0.8173317472027936  time=  95.88171792030334


 16%|█▌        | 31/200 [1:24:43<8:13:40, 175.27s/it]

epoch  30  loss  2.4240524813532827  time epoch  154.1378893852234


 16%|█▌        | 32/200 [1:27:18<7:53:40, 169.17s/it]

epoch  31  loss  2.345496851205826  time epoch  154.93796610832214


 16%|█▋        | 33/200 [1:29:53<7:38:41, 164.80s/it]

epoch  32  loss  2.5145298093557353  time epoch  154.59347081184387


 17%|█▋        | 34/200 [1:32:27<7:27:28, 161.74s/it]

epoch  33  loss  2.441766200959682  time epoch  154.58937454223633


 18%|█▊        | 35/200 [1:35:02<7:18:51, 159.59s/it]

epoch  34  loss  2.3745816014707106  time epoch  154.57027649879456


 18%|█▊        | 36/200 [1:37:37<7:12:07, 158.10s/it]

epoch  35  loss  2.281346549093725  time epoch  154.61661052703857


 18%|█▊        | 37/200 [1:40:11<7:06:44, 157.08s/it]

epoch  36  loss  2.2872685968875888  time epoch  154.71660494804382


 19%|█▉        | 38/200 [1:42:46<7:02:12, 156.37s/it]

epoch  37  loss  2.45525826588273  time epoch  154.71703839302063


 20%|█▉        | 39/200 [1:45:21<6:58:10, 155.84s/it]

epoch  38  loss  2.167169623076916  time epoch  154.6063530445099
epoch  39  loss  2.2605649061501016  time epoch  154.65599298477173


 20%|██        | 40/200 [1:49:32<8:11:53, 184.46s/it]

eval mse  60.65981778673306  eval mae  129.2528016886134  eval ssim  0.8508013145331998  time=  96.51109600067139


 20%|██        | 41/200 [1:52:06<7:44:45, 175.38s/it]

epoch  40  loss  2.238712816685438  time epoch  154.2054305076599


 21%|██        | 42/200 [1:54:40<7:25:06, 169.03s/it]

epoch  41  loss  2.2352839559316626  time epoch  154.2104630470276


 22%|██▏       | 43/200 [1:57:15<7:10:45, 164.62s/it]

epoch  42  loss  2.3566626414656633  time epoch  154.3381118774414


 22%|██▏       | 44/200 [1:59:49<6:59:48, 161.47s/it]

epoch  43  loss  2.203799089789391  time epoch  154.09497928619385


 22%|██▎       | 45/200 [2:02:23<6:51:36, 159.33s/it]

epoch  44  loss  2.355782863497734  time epoch  154.35280466079712


 23%|██▎       | 46/200 [2:04:57<6:45:06, 157.84s/it]

epoch  45  loss  2.3212297864258287  time epoch  154.34754085540771


 24%|██▎       | 47/200 [2:07:32<6:39:44, 156.76s/it]

epoch  46  loss  2.342099750787018  time epoch  154.24824047088623


 24%|██▍       | 48/200 [2:10:06<6:35:15, 156.03s/it]

epoch  47  loss  2.301534481346609  time epoch  154.31097888946533


 24%|██▍       | 49/200 [2:12:41<6:31:40, 155.64s/it]

epoch  48  loss  2.0586811490356927  time epoch  154.7213954925537
epoch  49  loss  2.2463012814521797  time epoch  154.6271686553955


 25%|██▌       | 50/200 [2:16:51<7:40:11, 184.08s/it]

eval mse  62.656764911238554  eval mae  140.0034155390065  eval ssim  0.8407803741700053  time=  95.76872324943542


 26%|██▌       | 51/200 [2:19:26<7:15:12, 175.25s/it]

epoch  50  loss  2.325495649874211  time epoch  154.65447187423706


 26%|██▌       | 52/200 [2:22:00<6:56:48, 168.98s/it]

epoch  51  loss  2.1008713394403467  time epoch  154.3311357498169


 26%|██▋       | 53/200 [2:24:35<6:43:28, 164.69s/it]

epoch  52  loss  2.0709684275090696  time epoch  154.67521214485168


 27%|██▋       | 54/200 [2:27:10<6:33:30, 161.71s/it]

epoch  53  loss  2.2054415859282024  time epoch  154.7761673927307


 28%|██▊       | 55/200 [2:29:44<6:25:48, 159.64s/it]

epoch  54  loss  2.2528974950313563  time epoch  154.81499767303467


 28%|██▊       | 56/200 [2:32:19<6:19:42, 158.21s/it]

epoch  55  loss  2.2991992667317396  time epoch  154.85941433906555


 28%|██▊       | 57/200 [2:34:54<6:14:38, 157.19s/it]

epoch  56  loss  2.151842462271451  time epoch  154.80725646018982


 29%|██▉       | 58/200 [2:37:29<6:10:05, 156.38s/it]

epoch  57  loss  2.1164078205823897  time epoch  154.47747993469238


 30%|██▉       | 59/200 [2:40:03<6:06:06, 155.79s/it]

epoch  58  loss  2.1091720901429656  time epoch  154.43157482147217
epoch  59  loss  2.1311471253633503  time epoch  154.37821650505066


 30%|███       | 60/200 [2:44:12<7:08:57, 183.84s/it]

eval mse  60.14990171201669  eval mae  138.33951349926602  eval ssim  0.835529961465205  time=  94.8445360660553


 30%|███       | 61/200 [2:46:47<6:45:22, 174.98s/it]

epoch  60  loss  2.152662718296051  time epoch  154.3217248916626


 31%|███       | 62/200 [2:49:21<6:28:13, 168.79s/it]

epoch  61  loss  2.226773326098918  time epoch  154.3374683856964


 32%|███▏      | 63/200 [2:51:55<6:15:33, 164.48s/it]

epoch  62  loss  2.1060519151389596  time epoch  154.40716218948364


 32%|███▏      | 64/200 [2:54:30<6:05:56, 161.44s/it]

epoch  63  loss  2.178011911362411  time epoch  154.36801385879517


 32%|███▎      | 65/200 [2:57:05<5:58:54, 159.52s/it]

epoch  64  loss  2.080204079300166  time epoch  155.0159204006195


 33%|███▎      | 66/200 [2:59:40<5:53:25, 158.25s/it]

epoch  65  loss  2.0172631569206714  time epoch  155.28581523895264


 34%|███▎      | 67/200 [3:02:15<5:48:45, 157.34s/it]

epoch  66  loss  2.18409513682127  time epoch  155.2074112892151


 34%|███▍      | 68/200 [3:04:50<5:44:24, 156.55s/it]

epoch  67  loss  2.1557837381958955  time epoch  154.72152495384216


 34%|███▍      | 69/200 [3:07:25<5:40:31, 155.97s/it]

epoch  68  loss  2.153947724401951  time epoch  154.60877466201782
epoch  69  loss  2.069047236442564  time epoch  154.75003480911255


 35%|███▌      | 70/200 [3:11:36<6:39:59, 184.61s/it]

eval mse  47.16546949155771  eval mae  116.25654538877451  eval ssim  0.8592204888843237  time=  96.64838004112244


 36%|███▌      | 71/200 [3:14:10<6:17:11, 175.44s/it]

epoch  70  loss  1.9839581549167635  time epoch  154.03960728645325


 36%|███▌      | 72/200 [3:16:44<6:00:28, 168.98s/it]

epoch  71  loss  2.242801770567894  time epoch  153.88869714736938


 36%|███▋      | 73/200 [3:19:18<5:48:05, 164.45s/it]

epoch  72  loss  2.045995590090752  time epoch  153.9022340774536


 37%|███▋      | 74/200 [3:21:52<5:38:42, 161.29s/it]

epoch  73  loss  2.1576069436967384  time epoch  153.89982914924622
