In [1]:
import torch
import torchvision
import torch.nn as nn
import numpy as np
import os 
from PIL import Image
from collections import defaultdict
from torch.utils.data import Dataset


def get_images_from_path(folder_path):
    # Dictionary to hold images, with subfolders as keys and images sorted by index
    images_by_subfolder = defaultdict(list)
    for root, dirs, files in os.walk(folder_path):
        for file in files:
            if file.endswith('.png') and 'img_' in file:
                try:
                    # Extract the numeric index x from the filename "img_x.png"
                    index = int(file.split('_')[1].split('.')[0])  # Extract x from "img_x.png"
                    # Get the relative subfolder path (excluding the base folder path)
                    subfolder_name = os.path.basename(root)
                    img_path = os.path.join(root, file)
                    # Use context manager to ensure the file is properly closed
                    with Image.open(img_path) as img:
                        # Append the image to the appropriate subfolder's list, sorted by index
                        images_by_subfolder[subfolder_name].append((index, img.convert('L')))
                except (ValueError, IndexError):
                    print(f"Skipping file with unexpected format: {file}")
    # Sort images within each subfolder by their index
    for subfolder, images in images_by_subfolder.items():
        # Sort by the numeric index (first element of the tuple)
        images_by_subfolder[subfolder] = [img for _, img in sorted(images, key=lambda x: x[0])]


    # Assuming images_by_subfolder is already populated
    # Create an empty list to store image data in the desired tensor shape
    tensor_images = []
    # Define the target image size (64x64)
    target_size = (64, 64)
    # Iterate through each subfolder
    for subfolder, images in images_by_subfolder.items():
        folder_images = []
        
        # Iterate through each image in the subfolder
        for img in images:
            # Resize the image to (64, 64)
            img_resized = img.resize(target_size)
            
            # Convert the image to a NumPy array and normalize to [0, 1] range
            img_array = np.array(img_resized) / 255.0  # Convert to float and normalize
            
            # Add a channel dimension (grayscale, so it's 1 channel)
            img_array = np.expand_dims(img_array, axis=-1)  # Shape becomes (64, 64, 1)
            
            # Add the image to the list
            folder_images.append(img_array)
        
        # Append the images from the current subfolder to the main list
        tensor_images.append(folder_images)
    # Convert the list to a NumPy array and then to a PyTorch tensor
    # Convert the list to a numpy array of shape (len(subfolders), num_images_per_subfolder, 64, 64, 1)
    tensor_images = np.array(tensor_images)
    # Convert to a PyTorch tensor (shape will be (batch_size, num_images, 1, 64, 64))
    tensor_images = torch.tensor(tensor_images)
    tensor_images = tensor_images.permute(0, 1, 4, 2, 3)
    return tensor_images


class ImageDataset(Dataset):
    def __init__(self, tensor_images):
        """
        Args:
            tensor_images (Tensor): A tensor of shape (num_subfolders, 15, 1, 64, 64)
        """
        self.tensor_images = tensor_images
        self.num_subfolders = tensor_images.shape[0]
        self.num_images_per_subfolder = tensor_images.shape[1]

    def __len__(self):
        # Return the number of subfolders
        return self.num_subfolders

    def __getitem__(self, idx):
        # Get the subfolder images
        subfolder_images = self.tensor_images[idx]

        # First element: None
        first_element = 0

        # Second element: First 10 images (index 0 to 9)
        second_element = subfolder_images[:10]

        # Third element: Last 5 images (index 10 to 14)
        third_element = subfolder_images[10:]

        # Return as a tuple
        return first_element, second_element, third_element


In [2]:
from torch.utils.data import DataLoader

# Path to the folder containing images
train_dataset = get_images_from_path('/kaggle/input/weather-data/weather_dataset/train')
train_dataset = ImageDataset(train_dataset)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

In [3]:
test_dataset = get_images_from_path('/kaggle/input/weather-data/weather_dataset/test')
test_dataset = ImageDataset(test_dataset)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

val_dataset = get_images_from_path('/kaggle/input/weather-data/weather_dataset/validation')
val_dataset = ImageDataset(val_dataset)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)


##### 

In [4]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.optim.lr_scheduler import ReduceLROnPlateau
import numpy as np
import random
import time
from tqdm import tqdm
from skimage.metrics import structural_similarity as ssim
import matplotlib.pyplot as plt
import torch.utils.data as data

import argparse
import os
import gzip
import numpy as np

##############################################################################################################
from numpy import *
from numpy.linalg import *
from scipy.special import factorial
from functools import reduce
import torch
import torch.nn as nn

from functools import reduce

__all__ = ['M2K','K2M']

def _apply_axis_left_dot(x, mats):
    assert x.dim() == len(mats)+1
    sizex = x.size()
    k = x.dim()-1
    for i in range(k):
        x = tensordot(mats[k-i-1], x, dim=[1,k])
    x = x.permute([k,]+list(range(k))).contiguous()
    x = x.view(sizex)
    return x

def _apply_axis_right_dot(x, mats):
    assert x.dim() == len(mats)+1
    sizex = x.size()
    k = x.dim()-1
    x = x.permute(list(range(1,k+1))+[0,])
    for i in range(k):
        x = tensordot(x, mats[i], dim=[0,0])
    x = x.contiguous()
    x = x.view(sizex)
    return x

class _MK(nn.Module):
    def __init__(self, shape):
        super(_MK, self).__init__()
        self._size = torch.Size(shape)
        self._dim = len(shape)
        M = []
        invM = []
        assert len(shape) > 0
        j = 0
        for l in shape:
            M.append(zeros((l,l)))
            for i in range(l):
                M[-1][i] = ((arange(l)-(l-1)//2)**i)/factorial(i)
            invM.append(inv(M[-1]))
            self.register_buffer('_M'+str(j), torch.from_numpy(M[-1]))
            self.register_buffer('_invM'+str(j), torch.from_numpy(invM[-1]))
            j += 1

    @property
    def M(self):
        return list(self._buffers['_M'+str(j)] for j in range(self.dim()))
    @property
    def invM(self):
        return list(self._buffers['_invM'+str(j)] for j in range(self.dim()))

    def size(self):
        return self._size
    def dim(self):
        return self._dim
    def _packdim(self, x):
        assert x.dim() >= self.dim()
        if x.dim() == self.dim():
            x = x[newaxis,:]
        x = x.contiguous()
        x = x.view([-1,]+list(x.size()[-self.dim():]))
        return x

    def forward(self):
        pass

class M2K(_MK):
    """
    convert moment matrix to convolution kernel
    Arguments:
        shape (tuple of int): kernel shape
    Usage:
        m2k = M2K([5,5])
        m = torch.randn(5,5,dtype=torch.float64)
        k = m2k(m)
    """
    def __init__(self, shape):
        super(M2K, self).__init__(shape)
    def forward(self, m):
        """
        m (Tensor): torch.size=[...,*self.shape]
        """
        sizem = m.size()
        m = self._packdim(m)
        m = _apply_axis_left_dot(m, self.invM)
        m = m.view(sizem)
        return m

class K2M(_MK):
    """
    convert convolution kernel to moment matrix
    Arguments:
        shape (tuple of int): kernel shape
    Usage:
        k2m = K2M([5,5])
        k = torch.randn(5,5,dtype=torch.float64)
        m = k2m(k)
    """
    def __init__(self, shape):
        super(K2M, self).__init__(shape)
    def forward(self, k):
        """
        k (Tensor): torch.size=[...,*self.shape]
        """
        sizek = k.size()
        k = self._packdim(k)
        k = _apply_axis_left_dot(k, self.M)
        k = k.view(sizek)
        return k


    
def tensordot(a,b,dim):
    """
    tensordot in PyTorch, see numpy.tensordot?
    """
    l = lambda x,y:x*y
    if isinstance(dim,int):
        a = a.contiguous()
        b = b.contiguous()
        sizea = a.size()
        sizeb = b.size()
        sizea0 = sizea[:-dim]
        sizea1 = sizea[-dim:]
        sizeb0 = sizeb[:dim]
        sizeb1 = sizeb[dim:]
        N = reduce(l, sizea1, 1)
        assert reduce(l, sizeb0, 1) == N
    else:
        adims = dim[0]
        bdims = dim[1]
        adims = [adims,] if isinstance(adims, int) else adims
        bdims = [bdims,] if isinstance(bdims, int) else bdims
        adims_ = set(range(a.dim())).difference(set(adims))
        adims_ = list(adims_)
        adims_.sort()
        perma = adims_+adims
        bdims_ = set(range(b.dim())).difference(set(bdims))
        bdims_ = list(bdims_)
        bdims_.sort()
        permb = bdims+bdims_
        a = a.permute(*perma).contiguous()
        b = b.permute(*permb).contiguous()

        sizea = a.size()
        sizeb = b.size()
        sizea0 = sizea[:-len(adims)]
        sizea1 = sizea[-len(adims):]
        sizeb0 = sizeb[:len(bdims)]
        sizeb1 = sizeb[len(bdims):]
        N = reduce(l, sizea1, 1)
        assert reduce(l, sizeb0, 1) == N
    a = a.view([-1,N])
    b = b.view([N,-1])
    c = a@b
    return c.view(sizea0+sizeb1)

##############################################################################################################
import torch
import torch.nn as nn

class PhyCell_Cell(nn.Module):
    def __init__(self, input_dim, F_hidden_dim, kernel_size, bias=1):
        super(PhyCell_Cell, self).__init__()
        self.input_dim  = input_dim
        self.F_hidden_dim = F_hidden_dim
        self.kernel_size = kernel_size
        self.padding     = kernel_size[0] // 2, kernel_size[1] // 2
        self.bias = bias
        
        self.F = nn.Sequential()
        self.F.add_module('conv1', nn.Conv2d(in_channels=input_dim, out_channels=F_hidden_dim, kernel_size=self.kernel_size, stride=(1,1), padding=self.padding))
        self.F.add_module('bn1',nn.GroupNorm( 7 ,F_hidden_dim))        
        self.F.add_module('conv2', nn.Conv2d(in_channels=F_hidden_dim, out_channels=input_dim, kernel_size=(1,1), stride=(1,1), padding=(0,0)))

        self.convgate = nn.Conv2d(in_channels=self.input_dim + self.input_dim,
                              out_channels= self.input_dim,
                              kernel_size=(3,3),
                              padding=(1,1), bias=self.bias)

    def forward(self, x, hidden): # x [batch_size, hidden_dim, height, width]      
        combined = torch.cat([x, hidden], dim=1)  # concatenate along channel axis
        combined_conv = self.convgate(combined)
        K = torch.sigmoid(combined_conv)
        hidden_tilde = hidden + self.F(hidden)        # prediction
        next_hidden = hidden_tilde + K * (x-hidden_tilde)   # correction , Haddamard product     
        return next_hidden

class PhyCell(nn.Module):
    def __init__(self, input_shape, input_dim, F_hidden_dims, n_layers, kernel_size, device):
        super(PhyCell, self).__init__()
        self.input_shape = input_shape
        self.input_dim = input_dim
        self.F_hidden_dims = F_hidden_dims
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.H = []  
        self.device = device
             
        cell_list = []
        for i in range(0, self.n_layers):
            cell_list.append(PhyCell_Cell(input_dim=input_dim,
                                          F_hidden_dim=self.F_hidden_dims[i],
                                          kernel_size=self.kernel_size))                                     
        self.cell_list = nn.ModuleList(cell_list)
        
       
    def forward(self, input_, first_timestep=False): # input_ [batch_size, 1, channels, width, height]    
        batch_size = input_.data.size()[0]
        if (first_timestep):   
            self.initHidden(batch_size) # init Hidden at each forward start
              
        for j,cell in enumerate(self.cell_list):
            if j==0: # bottom layer
                self.H[j] = cell(input_, self.H[j])
            else:
                self.H[j] = cell(self.H[j-1],self.H[j])
        
        return self.H , self.H 
    
    def initHidden(self,batch_size):
        self.H = [] 
        for i in range(self.n_layers):
            self.H.append( torch.zeros(batch_size, self.input_dim, self.input_shape[0], self.input_shape[1]).to(self.device) )

    def setHidden(self, H):
        self.H = H

        
class ConvLSTM_Cell(nn.Module):
    def __init__(self, input_shape, input_dim, hidden_dim, kernel_size, bias=1):              
        """
        input_shape: (int, int)
            Height and width of input tensor as (height, width).
        input_dim: int
            Number of channels of input tensor.
        hidden_dim: int
            Number of channels of hidden state.
        kernel_size: (int, int)
            Size of the convolutional kernel.
        bias: bool
            Whether or not to add the bias.
        """
        super(ConvLSTM_Cell, self).__init__()
        
        self.height, self.width = input_shape
        self.input_dim  = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.padding     = kernel_size[0] // 2, kernel_size[1] // 2
        self.bias        = bias
        
        self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim,
                              out_channels=4 * self.hidden_dim,
                              kernel_size=self.kernel_size,
                              padding=self.padding, bias=self.bias)
                 
    # we implement LSTM that process only one timestep 
    def forward(self,x, hidden): # x [batch, hidden_dim, width, height]          
        h_cur, c_cur = hidden
        
        combined = torch.cat([x, h_cur], dim=1)  # concatenate along channel axis
        combined_conv = self.conv(combined)
        cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1) 
        i = torch.sigmoid(cc_i)
        f = torch.sigmoid(cc_f)
        o = torch.sigmoid(cc_o)
        g = torch.tanh(cc_g)

        c_next = f * c_cur + i * g
        h_next = o * torch.tanh(c_next)
        return h_next, c_next


class ConvLSTM(nn.Module):
    def __init__(self, input_shape, input_dim, hidden_dims, n_layers, kernel_size,device):
        super(ConvLSTM, self).__init__()
        self.input_shape = input_shape
        self.input_dim = input_dim
        self.hidden_dims = hidden_dims
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.H, self.C = [],[]   
        self.device = device
        
        cell_list = []
        for i in range(0, self.n_layers):
            cur_input_dim = self.input_dim if i == 0 else self.hidden_dims[i-1]
            print('layer ',i,'input dim ', cur_input_dim, ' hidden dim ', self.hidden_dims[i])
            cell_list.append(ConvLSTM_Cell(input_shape=self.input_shape,
                                          input_dim=cur_input_dim,
                                          hidden_dim=self.hidden_dims[i],
                                          kernel_size=self.kernel_size))                                     
        self.cell_list = nn.ModuleList(cell_list)
        
       
    def forward(self, input_, first_timestep=False): # input_ [batch_size, 1, channels, width, height]    
        batch_size = input_.data.size()[0]
        if (first_timestep):   
            self.initHidden(batch_size) # init Hidden at each forward start
              
        for j,cell in enumerate(self.cell_list):
            if j==0: # bottom layer
                self.H[j], self.C[j] = cell(input_, (self.H[j],self.C[j]))
            else:
                self.H[j], self.C[j] = cell(self.H[j-1],(self.H[j],self.C[j]))
        
        return (self.H,self.C) , self.H   # (hidden, output)
    
    def initHidden(self,batch_size):
        self.H, self.C = [],[]  
        for i in range(self.n_layers):
            self.H.append( torch.zeros(batch_size,self.hidden_dims[i], self.input_shape[0], self.input_shape[1]).to(self.device) )
            self.C.append( torch.zeros(batch_size,self.hidden_dims[i], self.input_shape[0], self.input_shape[1]).to(self.device) )
    
    def setHidden(self, hidden):
        H,C = hidden
        self.H, self.C = H,C
 

class dcgan_conv(nn.Module):
    def __init__(self, nin, nout, stride):
        super(dcgan_conv, self).__init__()
        self.main = nn.Sequential(
                nn.Conv2d(in_channels=nin, out_channels=nout, kernel_size=(3,3), stride=stride, padding=1),
                nn.GroupNorm(16,nout),
                nn.LeakyReLU(0.2, inplace=True),
                )

    def forward(self, input):
        return self.main(input)

class dcgan_upconv(nn.Module):
    def __init__(self, nin, nout, stride):
        super(dcgan_upconv, self).__init__()
        if (stride ==2):
            output_padding = 1
        else:
            output_padding = 0
        self.main = nn.Sequential(
                nn.ConvTranspose2d(in_channels=nin,out_channels=nout,kernel_size=(3,3), stride=stride,padding=1,output_padding=output_padding),
                nn.GroupNorm(16,nout),
                nn.LeakyReLU(0.2, inplace=True),
                )

    def forward(self, input):
        return self.main(input)
        
class encoder_E(nn.Module):
    def __init__(self, nc=1, nf=32):
        super(encoder_E, self).__init__()
        # input is (1) x 64 x 64
        self.c1 = dcgan_conv(nc, nf, stride=2) # (32) x 32 x 32
        self.c2 = dcgan_conv(nf, nf, stride=1) # (32) x 32 x 32
        self.c3 = dcgan_conv(nf, 2*nf, stride=2) # (64) x 16 x 16

    def forward(self, input):
        h1 = self.c1(input)  
        h2 = self.c2(h1)    
        h3 = self.c3(h2)
        return h3

class decoder_D(nn.Module):
    def __init__(self, nc=1, nf=32):
        super(decoder_D, self).__init__()
        self.upc1 = dcgan_upconv(2*nf, nf, stride=2) #(32) x 32 x 32
        self.upc2 = dcgan_upconv(nf, nf, stride=1) #(32) x 32 x 32
        self.upc3 = nn.ConvTranspose2d(in_channels=nf,out_channels=nc,kernel_size=(3,3),stride=2,padding=1,output_padding=1)  #(nc) x 64 x 64

    def forward(self, input):      
        d1 = self.upc1(input) 
        d2 = self.upc2(d1)
        d3 = self.upc3(d2)  
        return d3  


class encoder_specific(nn.Module):
    def __init__(self, nc=64, nf=64):
        super(encoder_specific, self).__init__()
        self.c1 = dcgan_conv(nc, nf, stride=1) # (64) x 16 x 16
        self.c2 = dcgan_conv(nf, nf, stride=1) # (64) x 16 x 16

    def forward(self, input):
        h1 = self.c1(input)  
        h2 = self.c2(h1)     
        return h2

class decoder_specific(nn.Module):
    def __init__(self, nc=64, nf=64):
        super(decoder_specific, self).__init__()
        self.upc1 = dcgan_upconv(nf, nf, stride=1) #(64) x 16 x 16
        self.upc2 = dcgan_upconv(nf, nc, stride=1) #(32) x 32 x 32
        
    def forward(self, input):
        d1 = self.upc1(input) 
        d2 = self.upc2(d1)  
        return d2       

        
class EncoderRNN(torch.nn.Module):
    def __init__(self,phycell,convcell, device):
        super(EncoderRNN, self).__init__()
        self.encoder_E = encoder_E()   # general encoder 64x64x1 -> 32x32x32
        self.encoder_Ep = encoder_specific() # specific image encoder 32x32x32 -> 16x16x64
        self.encoder_Er = encoder_specific() 
        self.decoder_Dp = decoder_specific() # specific image decoder 16x16x64 -> 32x32x32 
        self.decoder_Dr = decoder_specific()     
        self.decoder_D = decoder_D()  # general decoder 32x32x32 -> 64x64x1 

        self.encoder_E = self.encoder_E.to(device)
        self.encoder_Ep = self.encoder_Ep.to(device) 
        self.encoder_Er = self.encoder_Er.to(device) 
        self.decoder_Dp = self.decoder_Dp.to(device) 
        self.decoder_Dr = self.decoder_Dr.to(device)               
        self.decoder_D = self.decoder_D.to(device)
        self.phycell = phycell.to(device)
        self.convcell = convcell.to(device)

    def forward(self, input, first_timestep=False, decoding=False):
        input = self.encoder_E(input) # general encoder 64x64x1 -> 32x32x32
    
        if decoding:  # input=None in decoding phase
            input_phys = None
        else:
            input_phys = self.encoder_Ep(input)
        input_conv = self.encoder_Er(input)     

        hidden1, output1 = self.phycell(input_phys, first_timestep)
        hidden2, output2 = self.convcell(input_conv, first_timestep)

        decoded_Dp = self.decoder_Dp(output1[-1])
        decoded_Dr = self.decoder_Dr(output2[-1])
        
        out_phys = torch.sigmoid(self.decoder_D(decoded_Dp)) # partial reconstructions for vizualization
        out_conv = torch.sigmoid(self.decoder_D(decoded_Dr))

        concat = decoded_Dp + decoded_Dr   
        output_image = torch.sigmoid( self.decoder_D(concat ))
        return out_phys, hidden1, output_image, out_phys, out_conv        

##############################################################################################################

device = torch.device("cuda")

constraints = torch.zeros((49,7,7)).to(device)
ind = 0
for i in range(0,7):
    for j in range(0,7):
        constraints[ind,i,j] = 1
        ind +=1    

def train_on_batch(input_tensor, target_tensor, encoder, encoder_optimizer, criterion,teacher_forcing_ratio):                
    encoder_optimizer.zero_grad()
    # input_tensor : torch.Size([batch_size, input_length, channels, cols, rows])
    input_length  = input_tensor.size(1)
    target_length = target_tensor.size(1)
    loss = 0
    for ei in range(input_length-1): 
        encoder_output, encoder_hidden, output_image,_,_ = encoder(input_tensor[:,ei,:,:,:].float(), (ei==0) )
        loss += criterion(output_image,input_tensor[:,ei+1,:,:,:])

    decoder_input = input_tensor[:,-1,:,:,:] # first decoder input = last image of input sequence
    
    use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False 
    for di in range(target_length):
        decoder_output, decoder_hidden, output_image,_,_ = encoder(decoder_input.float())
        target = target_tensor[:,di,:,:,:]
        loss += criterion(output_image,target)
        if use_teacher_forcing:
            decoder_input = target # Teacher forcing    
        else:
            decoder_input = output_image

    # Moment regularization  # encoder.phycell.cell_list[0].F.conv1.weight # size (nb_filters,in_channels,7,7)
    k2m = K2M([7,7]).to(device)
    for b in range(0,encoder.phycell.cell_list[0].input_dim):
        filters = encoder.phycell.cell_list[0].F.conv1.weight[:,b,:,:] # (nb_filters,7,7)     
        m = k2m(filters.double()) 
        m  = m.float()   
        loss += criterion(m, constraints) # constrains is a precomputed matrix   
    loss.backward()
    encoder_optimizer.step()
    return loss.item() / target_length


def trainIters(encoder, nepochs, print_every=10,eval_every=10,name=''):
    train_losses = []
    best_mse = float('inf')

    encoder_optimizer = torch.optim.Adam(encoder.parameters(),lr=0.001)
    scheduler_enc = ReduceLROnPlateau(encoder_optimizer, mode='min', patience=2,factor=0.1,verbose=True)
    criterion = nn.MSELoss()
    
    for epoch in tqdm(range(0, nepochs)):        
        t0 = time.time()
        loss_epoch = 0
        teacher_forcing_ratio = np.maximum(0 , 1 - epoch * 0.003) 
        
        for i, out in enumerate(train_loader, 0):
            input_tensor = out[1].to(device)
            target_tensor = out[2].to(device)
            loss = train_on_batch(input_tensor.float(), target_tensor.float(), encoder, encoder_optimizer, criterion, teacher_forcing_ratio)                                   
            loss_epoch += loss
                      
        train_losses.append(loss_epoch)        
        if (epoch+1) % print_every == 0:
            print('epoch ',epoch,  ' loss ',loss_epoch, ' time epoch ',time.time()-t0)
            
        if (epoch+1) % eval_every == 0:
            mse, mae,ssim = evaluate(encoder,test_loader) 
            scheduler_enc.step(mse)                   
            torch.save(encoder.state_dict(),'encoder_{}.pth'.format(name))                           
    return train_losses

    
def evaluate(encoder,loader):
    total_mse, total_mae,total_ssim,total_bce = 0,0,0,0
    t0 = time.time()
    with torch.no_grad():
        for i, out in enumerate(loader, 0):
            input_tensor = out[1].to(device)
            target_tensor = out[2].to(device)
            input_length = input_tensor.size()[1]
            target_length = target_tensor.size()[1]

            for ei in range(input_length-1):
                encoder_output, encoder_hidden, _,_,_  = encoder(input_tensor[:,ei,:,:,:].float(), (ei==0))

            decoder_input = input_tensor[:,-1,:,:,:] # first decoder input= last image of input sequence
            predictions = []

            for di in range(target_length):
                decoder_output, decoder_hidden, output_image,_,_ = encoder(decoder_input.float(), False, False)
                decoder_input = output_image
                predictions.append(output_image.cpu())

            input = input_tensor.cpu().numpy()
            target = target_tensor.cpu().numpy()
            predictions =  np.stack(predictions) # (10, batch_size, 1, 64, 64)
            predictions = predictions.swapaxes(0,1)  # (batch_size,10, 1, 64, 64)

            mse_batch = np.mean((predictions-target)**2 , axis=(0,1,2)).sum()
            mae_batch = np.mean(np.abs(predictions-target) ,  axis=(0,1,2)).sum() 
            total_mse += mse_batch
            total_mae += mae_batch
            for a in range(0,target.shape[0]):
                for b in range(0,target.shape[1]):
                    total_ssim += ssim(target[a,b,0,], predictions[a,b,0,], data_range=1.0) / (target.shape[0]*target.shape[1]) 

            
            cross_entropy = -target*np.log(predictions) - (1-target) * np.log(1-predictions)
            cross_entropy = cross_entropy.sum()
            cross_entropy = cross_entropy / (64*target_length)
            total_bce +=  cross_entropy
     
    print('eval mse ', total_mse/len(loader),  ' eval mae ', total_mae/len(loader),' eval ssim ',total_ssim/len(loader), ' time= ', time.time()-t0)        
    return total_mse/len(loader),  total_mae/len(loader), total_ssim/len(loader)


phycell  =  PhyCell(input_shape=(16,16), input_dim=64, F_hidden_dims=[49], n_layers=1, kernel_size=(7,7), device=device) 
convcell =  ConvLSTM(input_shape=(16,16), input_dim=64, hidden_dims=[128,128,64], n_layers=3, kernel_size=(3,3), device=device)   
encoder  = EncoderRNN(phycell, convcell, device)
  
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
   
print('phycell ' , count_parameters(phycell) )    
print('convcell ' , count_parameters(convcell) ) 
print('encoder ' , count_parameters(encoder) ) 

trainIters(encoder,200,print_every=1,eval_every=10,name="phydnet_eval")

#encoder.load_state_dict(torch.load('save/encoder_phydnet.pth'))
#encoder.eval()
#mse, mae,ssim = evaluate(encoder,test_loader) 
torch.save(encoder.state_dict(),'encoder_{}.pth'.format('phydnet_weather_200'))


  return sum(p.numel() for p in model.parameters() if p.requires_grad)


layer  0 input dim  64  hidden dim  128
layer  1 input dim  128  hidden dim  128
layer  2 input dim  128  hidden dim  64
phycell  230803
convcell  2508032
encoder  3091732


  0%|          | 1/200 [01:23<4:38:24, 83.94s/it]

epoch  0  loss  29.871303391456607  time epoch  83.94046664237976


  1%|          | 2/200 [02:47<4:35:26, 83.47s/it]

epoch  1  loss  18.21758080720901  time epoch  83.13774037361145


  2%|▏         | 3/200 [04:10<4:33:15, 83.22s/it]

epoch  2  loss  16.350564229488374  time epoch  82.93031549453735


  2%|▏         | 4/200 [05:33<4:31:44, 83.19s/it]

epoch  3  loss  14.98818594217301  time epoch  83.12866950035095


  2%|▎         | 5/200 [06:56<4:30:01, 83.09s/it]

epoch  4  loss  14.152812612056737  time epoch  82.90316200256348


  3%|▎         | 6/200 [08:19<4:28:35, 83.07s/it]

epoch  5  loss  13.630245256423947  time epoch  83.03374648094177


  4%|▎         | 7/200 [09:42<4:27:08, 83.05s/it]

epoch  6  loss  13.092449259757991  time epoch  83.0111517906189


  4%|▍         | 8/200 [11:05<4:25:48, 83.07s/it]

epoch  7  loss  12.690856075286861  time epoch  83.10028100013733


  4%|▍         | 9/200 [12:28<4:24:23, 83.05s/it]

epoch  8  loss  12.244361352920532  time epoch  83.02533769607544
epoch  9  loss  11.824924880266186  time epoch  83.06010890007019


  5%|▌         | 10/200 [14:14<4:45:17, 90.09s/it]

eval mse  25.366846964372996  eval mae  213.74614238143897  eval ssim  0.5070114920819011  time=  22.771437883377075


  6%|▌         | 11/200 [15:37<4:36:59, 87.93s/it]

epoch  10  loss  11.654116535186766  time epoch  83.03566288948059


  6%|▌         | 12/200 [17:00<4:30:43, 86.40s/it]

epoch  11  loss  11.15762325525283  time epoch  82.88525128364563


  6%|▋         | 13/200 [18:22<4:25:59, 85.35s/it]

epoch  12  loss  10.957501822710032  time epoch  82.9207215309143


  7%|▋         | 14/200 [19:45<4:22:20, 84.63s/it]

epoch  13  loss  10.539687609672546  time epoch  82.96916556358337


  8%|▊         | 15/200 [21:08<4:19:19, 84.10s/it]

epoch  14  loss  10.207694900035861  time epoch  82.88747262954712


  8%|▊         | 16/200 [22:31<4:17:02, 83.82s/it]

epoch  15  loss  9.972348362207413  time epoch  83.15480089187622


  8%|▊         | 17/200 [23:55<4:14:59, 83.61s/it]

epoch  16  loss  9.771883672475816  time epoch  83.1102523803711


  9%|▉         | 18/200 [25:18<4:13:08, 83.45s/it]

epoch  17  loss  9.567624545097354  time epoch  83.0951623916626


 10%|▉         | 19/200 [26:41<4:11:23, 83.34s/it]

epoch  18  loss  9.313882583379746  time epoch  83.05985450744629
epoch  19  loss  8.998328870534898  time epoch  83.03743410110474


 10%|█         | 20/200 [28:27<4:30:29, 90.16s/it]

eval mse  29.542836987367533  eval mae  224.48502868078737  eval ssim  0.5410876928536019  time=  23.001746654510498


 10%|█         | 21/200 [29:50<4:22:37, 88.03s/it]

epoch  20  loss  8.825432169437411  time epoch  83.06408905982971


 11%|█         | 22/200 [31:13<4:16:45, 86.55s/it]

epoch  21  loss  8.67826247811318  time epoch  83.09179759025574


 12%|█▏        | 23/200 [32:36<4:12:11, 85.49s/it]

epoch  22  loss  8.442799919843672  time epoch  83.015629529953


 12%|█▏        | 24/200 [33:59<4:08:33, 84.74s/it]

epoch  23  loss  8.370255285501477  time epoch  82.97607183456421


 12%|█▎        | 25/200 [35:22<4:05:36, 84.21s/it]

epoch  24  loss  8.20389031171799  time epoch  82.98443365097046


 13%|█▎        | 26/200 [36:45<4:03:12, 83.86s/it]

epoch  25  loss  8.021141743659973  time epoch  83.05517983436584


 14%|█▎        | 27/200 [38:08<4:01:03, 83.60s/it]

epoch  26  loss  7.946120965480805  time epoch  82.98895001411438


 14%|█▍        | 28/200 [39:31<3:59:20, 83.49s/it]

epoch  27  loss  7.746756255626679  time epoch  83.22933769226074


 14%|█▍        | 29/200 [40:54<3:57:34, 83.36s/it]

epoch  28  loss  7.6348612725734695  time epoch  83.05366349220276
epoch  29  loss  7.406646370887755  time epoch  82.91812562942505


 15%|█▌        | 30/200 [42:40<4:15:11, 90.07s/it]

eval mse  19.974749601487446  eval mae  184.60187545970885  eval ssim  0.5739129104534735  time=  22.773363828659058


 16%|█▌        | 31/200 [44:03<4:07:36, 87.91s/it]

epoch  30  loss  7.282423692941666  time epoch  82.86849808692932


 16%|█▌        | 32/200 [45:26<4:01:53, 86.39s/it]

epoch  31  loss  7.244899392127991  time epoch  82.8405294418335


 16%|█▋        | 33/200 [46:49<3:57:30, 85.33s/it]

epoch  32  loss  7.078332978487015  time epoch  82.86679697036743


 17%|█▋        | 34/200 [48:11<3:53:57, 84.57s/it]

epoch  33  loss  7.127733272314073  time epoch  82.77445387840271


 18%|█▊        | 35/200 [49:34<3:51:13, 84.08s/it]

epoch  34  loss  6.895451366901395  time epoch  82.94213509559631


 18%|█▊        | 36/200 [50:57<3:48:45, 83.69s/it]

epoch  35  loss  6.9670439720153805  time epoch  82.7945556640625


 18%|█▊        | 37/200 [52:20<3:46:41, 83.45s/it]

epoch  36  loss  6.916311538219452  time epoch  82.87028074264526


 19%|█▉        | 38/200 [53:43<3:44:55, 83.31s/it]

epoch  37  loss  6.677091908454893  time epoch  82.97970676422119


 20%|█▉        | 39/200 [55:06<3:43:30, 83.29s/it]

epoch  38  loss  6.5703058660030385  time epoch  83.26438736915588
epoch  39  loss  6.470276951789856  time epoch  83.1836256980896


 20%|██        | 40/200 [56:52<4:00:19, 90.12s/it]

eval mse  19.489778226708793  eval mae  181.6136941321822  eval ssim  0.5832445255983348  time=  22.83205270767212


 20%|██        | 41/200 [58:15<3:53:03, 87.95s/it]

epoch  40  loss  6.439573687314988  time epoch  82.8770318031311


 21%|██        | 42/200 [59:38<3:47:38, 86.44s/it]

epoch  41  loss  6.293355318903919  time epoch  82.93401789665222


 22%|██▏       | 43/200 [1:01:01<3:43:26, 85.39s/it]

epoch  42  loss  6.237348997592925  time epoch  82.94126534461975


 22%|██▏       | 44/200 [1:02:24<3:40:10, 84.68s/it]

epoch  43  loss  6.112565249204636  time epoch  83.02056622505188


 22%|██▎       | 45/200 [1:03:47<3:37:24, 84.16s/it]

epoch  44  loss  6.131030717492103  time epoch  82.92735528945923


 23%|██▎       | 46/200 [1:05:10<3:35:06, 83.81s/it]

epoch  45  loss  6.024691745638848  time epoch  83.00178408622742


 24%|██▎       | 47/200 [1:06:33<3:33:05, 83.57s/it]

epoch  46  loss  5.952096685767172  time epoch  82.99645400047302


 24%|██▍       | 48/200 [1:07:56<3:31:13, 83.38s/it]

epoch  47  loss  5.815231963992117  time epoch  82.9432520866394


 24%|██▍       | 49/200 [1:09:19<3:29:35, 83.28s/it]

epoch  48  loss  5.81156147122383  time epoch  83.04849290847778
epoch  49  loss  5.71274111866951  time epoch  82.96325182914734


 25%|██▌       | 50/200 [1:11:05<3:45:18, 90.12s/it]

eval mse  20.161224437054894  eval mae  191.18570369143666  eval ssim  0.567387004726482  time=  23.09364604949951


 26%|██▌       | 51/200 [1:12:28<3:38:27, 87.97s/it]

epoch  50  loss  5.671284583210947  time epoch  82.93890690803528


 26%|██▌       | 52/200 [1:13:51<3:33:06, 86.40s/it]

epoch  51  loss  5.529719978570939  time epoch  82.72927927970886


 26%|██▋       | 53/200 [1:15:14<3:29:01, 85.32s/it]

epoch  52  loss  5.466522762179376  time epoch  82.80327081680298


 27%|██▋       | 54/200 [1:16:36<3:25:54, 84.62s/it]

epoch  53  loss  5.466482177376747  time epoch  82.97891402244568


 28%|██▊       | 55/200 [1:17:59<3:23:09, 84.07s/it]

epoch  54  loss  5.3385376930236825  time epoch  82.78224301338196


 28%|██▊       | 56/200 [1:19:22<3:20:49, 83.68s/it]

epoch  55  loss  5.246641534566876  time epoch  82.7718436717987


 28%|██▊       | 57/200 [1:20:45<3:18:40, 83.36s/it]

epoch  56  loss  5.191715970635415  time epoch  82.62450289726257


 29%|██▉       | 58/200 [1:22:08<3:16:55, 83.21s/it]

epoch  57  loss  5.1817360728979125  time epoch  82.84739446640015


 30%|██▉       | 59/200 [1:23:30<3:15:14, 83.08s/it]

epoch  58  loss  5.065925797820092  time epoch  82.7720296382904
epoch  59  loss  5.073302233219149  time epoch  82.87310528755188


 30%|███       | 60/200 [1:25:16<3:29:45, 89.90s/it]

eval mse  24.418881055009024  eval mae  207.28680514420816  eval ssim  0.5604307952285238  time=  22.903430700302124


 30%|███       | 61/200 [1:26:39<3:23:19, 87.77s/it]

epoch  60  loss  5.100867715477946  time epoch  82.79829001426697


 31%|███       | 62/200 [1:28:02<3:18:26, 86.28s/it]

epoch  61  loss  4.901067391037943  time epoch  82.79894924163818


 32%|███▏      | 63/200 [1:29:25<3:14:43, 85.28s/it]

epoch  62  loss  4.830739018321037  time epoch  82.96021461486816


 32%|███▏      | 64/200 [1:30:47<3:11:33, 84.51s/it]

epoch  63  loss  4.7799711674451855  time epoch  82.69607853889465


 32%|███▎      | 65/200 [1:32:10<3:08:56, 83.97s/it]

epoch  64  loss  4.836532938480376  time epoch  82.72743129730225


 33%|███▎      | 66/200 [1:33:33<3:06:42, 83.60s/it]

epoch  65  loss  4.693202024698261  time epoch  82.72004866600037


 34%|███▎      | 67/200 [1:34:55<3:04:40, 83.32s/it]

epoch  66  loss  4.613938888907433  time epoch  82.65378618240356


 34%|███▍      | 68/200 [1:36:18<3:02:54, 83.14s/it]

epoch  67  loss  4.599037656188009  time epoch  82.73449802398682


 34%|███▍      | 69/200 [1:37:41<3:01:14, 83.01s/it]

epoch  68  loss  4.523393604159352  time epoch  82.69787454605103
epoch  69  loss  4.506723162531853  time epoch  82.72149133682251


 35%|███▌      | 70/200 [1:39:26<3:14:22, 89.71s/it]

eval mse  19.896196026123043  eval mae  182.09819245532464  eval ssim  0.5929634178348308  time=  22.598747730255127


 36%|███▌      | 71/200 [1:40:49<3:08:18, 87.58s/it]

epoch  70  loss  4.4279025495052355  time epoch  82.61571788787842


 36%|███▌      | 72/200 [1:42:12<3:03:41, 86.11s/it]

epoch  71  loss  4.413888201117514  time epoch  82.66659307479858


 36%|███▋      | 73/200 [1:43:34<3:00:03, 85.06s/it]

epoch  72  loss  4.412725624442103  time epoch  82.62166786193848


 37%|███▋      | 74/200 [1:44:57<2:57:06, 84.34s/it]

epoch  73  loss  4.396558463573455  time epoch  82.64360523223877


 38%|███▊      | 75/200 [1:46:19<2:54:38, 83.82s/it]

epoch  74  loss  4.394533699750901  time epoch  82.62366962432861


 38%|███▊      | 76/200 [1:47:42<2:52:30, 83.47s/it]

epoch  75  loss  4.385345414280892  time epoch  82.64598870277405


 38%|███▊      | 77/200 [1:49:05<2:50:37, 83.23s/it]

epoch  76  loss  4.382226002216338  time epoch  82.67686772346497


 39%|███▉      | 78/200 [1:50:27<2:48:50, 83.04s/it]

epoch  77  loss  4.379563611745832  time epoch  82.57249975204468


 40%|███▉      | 79/200 [1:51:50<2:47:14, 82.93s/it]

epoch  78  loss  4.369621455669403  time epoch  82.68258571624756
epoch  79  loss  4.35564156472683  time epoch  82.65882635116577


 40%|████      | 80/200 [1:53:35<2:59:17, 89.64s/it]

eval mse  19.683888271682623  eval mae  178.36042419121654  eval ssim  0.6025908158435054  time=  22.610783576965332


 40%|████      | 81/200 [1:54:58<2:53:39, 87.56s/it]

epoch  80  loss  4.358685180544855  time epoch  82.70932459831238


 41%|████      | 82/200 [1:56:21<2:49:16, 86.07s/it]

epoch  81  loss  4.333789169788361  time epoch  82.58466935157776


 42%|████▏     | 83/200 [1:57:43<2:45:48, 85.03s/it]

epoch  82  loss  4.334989991784096  time epoch  82.61390852928162


 42%|████▏     | 84/200 [1:59:06<2:42:59, 84.30s/it]

epoch  83  loss  4.324621230363843  time epoch  82.60051465034485


 42%|████▎     | 85/200 [2:00:29<2:40:50, 83.91s/it]

epoch  84  loss  4.31522336602211  time epoch  83.00227332115173


 43%|████▎     | 86/200 [2:01:52<2:38:47, 83.57s/it]

epoch  85  loss  4.299019882082939  time epoch  82.77752232551575


 44%|████▎     | 87/200 [2:03:14<2:36:47, 83.25s/it]

epoch  86  loss  4.304493662714957  time epoch  82.50163865089417


 44%|████▍     | 88/200 [2:04:37<2:35:00, 83.04s/it]

epoch  87  loss  4.278626680374147  time epoch  82.55661392211914


 44%|████▍     | 89/200 [2:06:00<2:33:32, 83.00s/it]

epoch  88  loss  4.282988256216052  time epoch  82.88559126853943
epoch  89  loss  4.263133040070533  time epoch  82.92275023460388


 45%|████▌     | 90/200 [2:07:45<2:44:42, 89.84s/it]

eval mse  20.015683286079646  eval mae  179.0577033883536  eval ssim  0.6027655938718669  time=  22.85414958000183


 46%|████▌     | 91/200 [2:09:08<2:39:25, 87.76s/it]

epoch  90  loss  4.267397806048393  time epoch  82.8982892036438


 46%|████▌     | 92/200 [2:10:31<2:35:22, 86.32s/it]

epoch  91  loss  4.242937353253365  time epoch  82.94634342193604


 46%|████▋     | 93/200 [2:11:54<2:32:04, 85.27s/it]

epoch  92  loss  4.244135695695875  time epoch  82.84061551094055


 47%|████▋     | 94/200 [2:13:17<2:29:22, 84.56s/it]

epoch  93  loss  4.216614234447479  time epoch  82.88242745399475


 48%|████▊     | 95/200 [2:14:40<2:27:06, 84.06s/it]

epoch  94  loss  4.21401962041855  time epoch  82.89650893211365


 48%|████▊     | 96/200 [2:16:03<2:25:03, 83.69s/it]

epoch  95  loss  4.194008332490921  time epoch  82.81397080421448


 48%|████▊     | 97/200 [2:17:26<2:23:15, 83.45s/it]

epoch  96  loss  4.179788494110108  time epoch  82.89028024673462


 49%|████▉     | 98/200 [2:18:48<2:21:33, 83.27s/it]

epoch  97  loss  4.180729851126673  time epoch  82.85785555839539


 50%|████▉     | 99/200 [2:20:11<2:19:58, 83.15s/it]

epoch  98  loss  4.15275664329529  time epoch  82.86655044555664
epoch  99  loss  4.1580074846744575  time epoch  83.13238644599915


 50%|█████     | 100/200 [2:21:58<2:30:06, 90.07s/it]

eval mse  20.500065316310394  eval mae  180.47409931445898  eval ssim  0.6027407911405952  time=  23.048158884048462


 50%|█████     | 101/200 [2:23:21<2:25:10, 87.99s/it]

epoch  100  loss  4.136803871393203  time epoch  83.12140798568726


 51%|█████     | 102/200 [2:24:44<2:21:18, 86.52s/it]

epoch  101  loss  4.137809520959857  time epoch  83.09002447128296


 52%|█████▏    | 103/200 [2:26:07<2:18:08, 85.45s/it]

epoch  102  loss  4.134835258126259  time epoch  82.96269130706787


 52%|█████▏    | 104/200 [2:27:30<2:15:30, 84.69s/it]

epoch  103  loss  4.1385493308305765  time epoch  82.91793203353882


 52%|█████▎    | 105/200 [2:28:52<2:13:13, 84.14s/it]

epoch  104  loss  4.137105199694636  time epoch  82.86413955688477


 53%|█████▎    | 106/200 [2:30:16<2:11:20, 83.84s/it]

epoch  105  loss  4.133338868618014  time epoch  83.11858034133911


 54%|█████▎    | 107/200 [2:31:38<2:09:27, 83.52s/it]

epoch  106  loss  4.132239651679992  time epoch  82.7664213180542


 54%|█████▍    | 108/200 [2:33:01<2:07:42, 83.29s/it]

epoch  107  loss  4.135428556799889  time epoch  82.77181696891785


 55%|█████▍    | 109/200 [2:34:24<2:06:04, 83.12s/it]

epoch  108  loss  4.13126765191555  time epoch  82.7189028263092
epoch  109  loss  4.1257226973772045  time epoch  82.72488808631897


 55%|█████▌    | 110/200 [2:36:09<2:14:42, 89.81s/it]

eval mse  20.38212116004333  eval mae  180.02484022762192  eval ssim  0.6022605786955287  time=  22.664498805999756


 56%|█████▌    | 111/200 [2:37:32<2:10:04, 87.69s/it]

epoch  110  loss  4.129494208097458  time epoch  82.74765729904175


 56%|█████▌    | 112/200 [2:38:55<2:06:22, 86.16s/it]

epoch  111  loss  4.123183801770209  time epoch  82.5857207775116


 56%|█████▋    | 113/200 [2:40:17<2:03:28, 85.15s/it]

epoch  112  loss  4.123659694194794  time epoch  82.79704260826111


 57%|█████▋    | 114/200 [2:41:40<2:00:57, 84.39s/it]

epoch  113  loss  4.123670974373817  time epoch  82.59784245491028


 57%|█████▊    | 115/200 [2:43:03<1:58:49, 83.87s/it]

epoch  114  loss  4.126704987883568  time epoch  82.67317938804626


 58%|█████▊    | 116/200 [2:44:25<1:56:54, 83.50s/it]

epoch  115  loss  4.113927081227303  time epoch  82.64000988006592


 58%|█████▊    | 117/200 [2:45:48<1:55:09, 83.25s/it]

epoch  116  loss  4.123156079649925  time epoch  82.65996146202087


 59%|█████▉    | 118/200 [2:47:11<1:53:31, 83.07s/it]

epoch  117  loss  4.120359358191492  time epoch  82.65452885627747


 60%|█████▉    | 119/200 [2:48:33<1:51:57, 82.93s/it]

epoch  118  loss  4.115887671709059  time epoch  82.61011385917664
epoch  119  loss  4.118467319011689  time epoch  82.66501331329346


 60%|██████    | 120/200 [2:50:19<1:59:30, 89.63s/it]

eval mse  20.45201426639322  eval mae  180.26802719537025  eval ssim  0.6024137987758376  time=  22.571511030197144


 60%|██████    | 121/200 [2:51:41<1:55:14, 87.52s/it]

epoch  120  loss  4.123876431584358  time epoch  82.60002017021179


 61%|██████    | 122/200 [2:53:04<1:51:51, 86.04s/it]

epoch  121  loss  4.113316768407821  time epoch  82.58565592765808


 62%|██████▏   | 123/200 [2:54:26<1:49:05, 85.01s/it]

epoch  122  loss  4.110655492544176  time epoch  82.60496211051941


 62%|██████▏   | 124/200 [2:55:49<1:46:48, 84.32s/it]

epoch  123  loss  4.112520274519921  time epoch  82.7008969783783


 62%|██████▎   | 125/200 [2:57:12<1:44:42, 83.77s/it]

epoch  124  loss  4.103931888937949  time epoch  82.4942696094513


 63%|██████▎   | 126/200 [2:58:34<1:42:50, 83.38s/it]

epoch  125  loss  4.106375795602798  time epoch  82.47521567344666


 64%|██████▎   | 127/200 [2:59:56<1:41:07, 83.11s/it]

epoch  126  loss  4.0965618789196006  time epoch  82.47836375236511


 64%|██████▍   | 128/200 [3:01:19<1:39:31, 82.94s/it]

epoch  127  loss  4.110226556658745  time epoch  82.5212230682373


 64%|██████▍   | 129/200 [3:02:42<1:37:59, 82.81s/it]

epoch  128  loss  4.1005742192268375  time epoch  82.52722454071045
epoch  129  loss  4.098058357834816  time epoch  82.50554060935974


 65%|██████▌   | 130/200 [3:04:27<1:44:26, 89.52s/it]

eval mse  20.461360421853538  eval mae  180.16810672317357  eval ssim  0.6020201041514779  time=  22.628646850585938


 66%|██████▌   | 131/200 [3:05:49<1:40:33, 87.44s/it]

epoch  130  loss  4.09052456021309  time epoch  82.59286856651306


 66%|██████▌   | 132/200 [3:07:12<1:37:26, 85.98s/it]

epoch  131  loss  4.101493248343467  time epoch  82.55817866325378


 66%|██████▋   | 133/200 [3:08:34<1:34:52, 84.96s/it]

epoch  132  loss  4.10227136015892  time epoch  82.59801197052002


 67%|██████▋   | 134/200 [3:09:57<1:32:40, 84.25s/it]

epoch  133  loss  4.100677055120468  time epoch  82.58063864707947


 68%|██████▊   | 135/200 [3:11:20<1:30:43, 83.74s/it]

epoch  134  loss  4.0952956706285475  time epoch  82.54866027832031


 68%|██████▊   | 136/200 [3:12:42<1:28:56, 83.38s/it]

epoch  135  loss  4.1008059144020095  time epoch  82.55395102500916


 68%|██████▊   | 137/200 [3:14:05<1:27:20, 83.18s/it]

epoch  136  loss  4.103809195756912  time epoch  82.7008113861084


 69%|██████▉   | 138/200 [3:15:27<1:25:47, 83.02s/it]

epoch  137  loss  4.095106986165047  time epoch  82.63635730743408


 70%|██████▉   | 139/200 [3:16:50<1:24:16, 82.89s/it]

epoch  138  loss  4.099648013710976  time epoch  82.57925462722778
epoch  139  loss  4.1105956703424456  time epoch  82.4810619354248


 70%|███████   | 140/200 [3:18:35<1:29:32, 89.54s/it]

eval mse  20.513727243233657  eval mae  180.25702346318576  eval ssim  0.6025281771177821  time=  22.545645713806152


 70%|███████   | 141/200 [3:19:58<1:25:58, 87.43s/it]

epoch  140  loss  4.100688415765765  time epoch  82.49700498580933


 71%|███████   | 142/200 [3:21:20<1:23:05, 85.96s/it]

epoch  141  loss  4.107346284389496  time epoch  82.55228734016418


 72%|███████▏  | 143/200 [3:22:43<1:20:40, 84.93s/it]

epoch  142  loss  4.10295777320862  time epoch  82.49746680259705


 72%|███████▏  | 144/200 [3:24:05<1:18:34, 84.19s/it]

epoch  143  loss  4.094086304306984  time epoch  82.4779405593872


 72%|███████▎  | 145/200 [3:25:28<1:16:44, 83.71s/it]

epoch  144  loss  4.101066923141478  time epoch  82.59172677993774


 73%|███████▎  | 146/200 [3:26:50<1:15:01, 83.37s/it]

epoch  145  loss  4.1022306889295574  time epoch  82.560063123703


 74%|███████▎  | 147/200 [3:28:13<1:13:25, 83.12s/it]

epoch  146  loss  4.10077896416187  time epoch  82.53672170639038


 74%|███████▍  | 148/200 [3:29:35<1:11:53, 82.95s/it]

epoch  147  loss  4.09946447312832  time epoch  82.54132747650146


 74%|███████▍  | 149/200 [3:30:58<1:10:23, 82.81s/it]

epoch  148  loss  4.101975050568583  time epoch  82.4794979095459
epoch  149  loss  4.097974467277524  time epoch  82.56442785263062


 75%|███████▌  | 150/200 [3:32:43<1:14:34, 89.50s/it]

eval mse  20.49413575373794  eval mae  180.2377822947954  eval ssim  0.6025435884624747  time=  22.514856576919556


 76%|███████▌  | 151/200 [3:34:06<1:11:23, 87.42s/it]

epoch  150  loss  4.104212442040445  time epoch  82.55867552757263


 76%|███████▌  | 152/200 [3:35:28<1:08:45, 85.94s/it]

epoch  151  loss  4.112181219458579  time epoch  82.5059130191803


 76%|███████▋  | 153/200 [3:36:51<1:06:30, 84.91s/it]

epoch  152  loss  4.098515507578851  time epoch  82.48847436904907


 77%|███████▋  | 154/200 [3:38:13<1:04:32, 84.19s/it]

epoch  153  loss  4.099141773581504  time epoch  82.53061366081238


 78%|███████▊  | 155/200 [3:39:35<1:02:44, 83.66s/it]

epoch  154  loss  4.104522299766541  time epoch  82.40650987625122


 78%|███████▊  | 156/200 [3:40:58<1:01:06, 83.33s/it]

epoch  155  loss  4.113370403647424  time epoch  82.56056213378906


 78%|███████▊  | 157/200 [3:42:20<59:31, 83.07s/it]  

epoch  156  loss  4.104977908730508  time epoch  82.44697523117065


 79%|███████▉  | 158/200 [3:43:43<58:02, 82.92s/it]

epoch  157  loss  4.105924460291862  time epoch  82.58414840698242


 80%|███████▉  | 159/200 [3:45:06<56:34, 82.79s/it]

epoch  158  loss  4.108896598219871  time epoch  82.49127960205078
epoch  159  loss  4.102937623858452  time epoch  82.5196316242218


 80%|████████  | 160/200 [3:46:51<59:38, 89.46s/it]

eval mse  20.513110505181857  eval mae  180.26327800352203  eval ssim  0.6025020643658366  time=  22.45680284500122


 80%|████████  | 161/200 [3:48:13<56:47, 87.37s/it]

epoch  160  loss  4.120111766457559  time epoch  82.49202299118042


 81%|████████  | 162/200 [3:49:36<54:25, 85.93s/it]

epoch  161  loss  4.110804328322409  time epoch  82.55754280090332


 82%|████████▏ | 163/200 [3:50:58<52:20, 84.89s/it]

epoch  162  loss  4.10560349225998  time epoch  82.47736263275146


 82%|████████▏ | 164/200 [3:52:21<50:29, 84.16s/it]

epoch  163  loss  4.098942026495933  time epoch  82.44434571266174


 82%|████████▎ | 165/200 [3:53:43<48:48, 83.66s/it]

epoch  164  loss  4.10663000047207  time epoch  82.5075933933258


 83%|████████▎ | 166/200 [3:55:06<47:13, 83.33s/it]

epoch  165  loss  4.116805288195612  time epoch  82.544504404068


 84%|████████▎ | 167/200 [3:56:28<45:41, 83.07s/it]

epoch  166  loss  4.110636338591576  time epoch  82.47596716880798


 84%|████████▍ | 168/200 [3:57:51<44:12, 82.89s/it]

epoch  167  loss  4.099751546978951  time epoch  82.44954586029053


 84%|████████▍ | 169/200 [3:59:13<42:45, 82.75s/it]

epoch  168  loss  4.109323477745057  time epoch  82.41638922691345
epoch  169  loss  4.106326788663864  time epoch  82.3900830745697


 85%|████████▌ | 170/200 [4:00:58<44:41, 89.38s/it]

eval mse  20.518553323148367  eval mae  180.2834730193711  eval ssim  0.6024024772587422  time=  22.44735312461853


 86%|████████▌ | 171/200 [4:02:20<42:12, 87.32s/it]

epoch  170  loss  4.110209396481514  time epoch  82.5121841430664


 86%|████████▌ | 172/200 [4:03:43<40:03, 85.85s/it]

epoch  171  loss  4.108400875329973  time epoch  82.4166419506073


 86%|████████▋ | 173/200 [4:05:05<38:10, 84.83s/it]

epoch  172  loss  4.105450427532197  time epoch  82.44637155532837


 87%|████████▋ | 174/200 [4:06:28<36:27, 84.13s/it]

epoch  173  loss  4.112995579838753  time epoch  82.4874176979065


 88%|████████▊ | 175/200 [4:07:50<34:50, 83.62s/it]

epoch  174  loss  4.110492345690725  time epoch  82.43843483924866


 88%|████████▊ | 176/200 [4:09:13<33:18, 83.27s/it]

epoch  175  loss  4.0989005535841  time epoch  82.45656895637512


 88%|████████▊ | 177/200 [4:10:35<31:50, 83.05s/it]

epoch  176  loss  4.100954782962799  time epoch  82.5304388999939


 89%|████████▉ | 178/200 [4:11:58<30:23, 82.87s/it]

epoch  177  loss  4.102192309498787  time epoch  82.44217658042908


 90%|████████▉ | 179/200 [4:13:20<28:57, 82.76s/it]

epoch  178  loss  4.105202215909958  time epoch  82.50632810592651
epoch  179  loss  4.108557745814323  time epoch  82.45974564552307


 90%|█████████ | 180/200 [4:15:05<29:48, 89.43s/it]

eval mse  20.520081400676695  eval mae  180.29428405204166  eval ssim  0.6023670439632473  time=  22.486343145370483


 90%|█████████ | 181/200 [4:16:27<27:39, 87.33s/it]

epoch  180  loss  4.111550232768058  time epoch  82.43606233596802


 91%|█████████ | 182/200 [4:17:50<25:45, 85.87s/it]

epoch  181  loss  4.113076478242874  time epoch  82.47570872306824


 92%|█████████▏| 183/200 [4:19:12<24:02, 84.85s/it]

epoch  182  loss  4.116122376918794  time epoch  82.45768141746521


 92%|█████████▏| 184/200 [4:20:35<22:26, 84.15s/it]

epoch  183  loss  4.123499527573584  time epoch  82.50877857208252


 92%|█████████▎| 185/200 [4:21:57<20:54, 83.65s/it]

epoch  184  loss  4.1179613232612615  time epoch  82.48551321029663


 93%|█████████▎| 186/200 [4:23:20<19:26, 83.30s/it]

epoch  185  loss  4.114222162961961  time epoch  82.47043919563293


 94%|█████████▎| 187/200 [4:24:42<17:59, 83.07s/it]

epoch  186  loss  4.108787357807161  time epoch  82.55051445960999


 94%|█████████▍| 188/200 [4:26:05<16:34, 82.89s/it]

epoch  187  loss  4.111010414361953  time epoch  82.45599365234375


 94%|█████████▍| 189/200 [4:27:27<15:10, 82.75s/it]

epoch  188  loss  4.1100010693073274  time epoch  82.42895150184631
epoch  189  loss  4.125945496559142  time epoch  82.43234825134277


 95%|█████████▌| 190/200 [4:29:12<14:53, 89.39s/it]

eval mse  20.533290700573694  eval mae  180.33760778749576  eval ssim  0.6022884715673198  time=  22.410666704177856


 96%|█████████▌| 191/200 [4:30:35<13:05, 87.31s/it]

epoch  190  loss  4.114041548967361  time epoch  82.4724771976471


 96%|█████████▌| 192/200 [4:31:57<11:26, 85.84s/it]

epoch  191  loss  4.112463149428367  time epoch  82.4051923751831


 96%|█████████▋| 193/200 [4:33:20<09:53, 84.82s/it]

epoch  192  loss  4.120964458584785  time epoch  82.42832684516907


 97%|█████████▋| 194/200 [4:34:42<08:24, 84.12s/it]

epoch  193  loss  4.114839631319045  time epoch  82.50529336929321


 98%|█████████▊| 195/200 [4:36:04<06:58, 83.62s/it]

epoch  194  loss  4.119736760854722  time epoch  82.43307733535767


 98%|█████████▊| 196/200 [4:37:27<05:33, 83.29s/it]

epoch  195  loss  4.119505396485329  time epoch  82.51364064216614


 98%|█████████▊| 197/200 [4:38:49<04:09, 83.03s/it]

epoch  196  loss  4.12648673951626  time epoch  82.44149398803711


 99%|█████████▉| 198/200 [4:40:12<02:45, 82.89s/it]

epoch  197  loss  4.105794262886048  time epoch  82.56826877593994


100%|█████████▉| 199/200 [4:41:35<01:22, 82.78s/it]

epoch  198  loss  4.119596847891808  time epoch  82.52569556236267
epoch  199  loss  4.114895778894424  time epoch  82.46853470802307


100%|██████████| 200/200 [4:43:19<00:00, 85.00s/it]

eval mse  20.535118786559313  eval mae  180.34288546097875  eval ssim  0.6022884795168806  time=  22.400813817977905



