In [1]:
from skimage.metrics import structural_similarity as ssim
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.optim.lr_scheduler import ReduceLROnPlateau
import numpy as np
import random
import time
from PhyDNet.models.models import ConvLSTM,PhyCell, EncoderRNN
from PhyDNet.data.moving_mnist import MovingMNIST
from PhyDNet.constrain_moments import K2M
import argparse
from tqdm import tqdm


In [2]:

# Run on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Run configuration. These notebook-level values replace the command-line
# options of the script version (--root, --batch_size, --nepochs,
# --print_every, --eval_every, --save_name; script defaults were
# root='data/', batch_size=16, nepochs=2001, print_every=1, eval_every=10,
# save_name='phydnet').
batch_size=32
eval_every=1
print_every=1
nepochs=1
data_range = 1.0 # data range 0 to 1 - images normalized this way
root='PhyDNet/data/'
save_name='phydnet'

# Seed every RNG the notebook draws from (teacher-forcing coin flips use
# `random`, weight init and shuffling use torch) so that
# "Restart Kernel -> Run All" reproduces the same run.
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

In [3]:
from torch.utils.data import random_split

# NOTE(review): both subsets below are carved out of the `is_train=False`
# data -- presumably on purpose to keep the run small; confirm.
mm = MovingMNIST(root=root, is_train=False, n_frames_input=10, n_frames_output=10, num_objects=[2])
# Split ratio for train and test
train_ratio = 0.8
test_ratio = 0.2

# Calculate the lengths of train and test sets; the test set takes whatever
# is left over so the two lengths always sum to len(mm).
train_length = int(train_ratio * len(mm))
test_length = len(mm) - train_length

# Split the dataset with an explicitly seeded generator. Without it the
# membership of the two subsets changes on every kernel restart, so a
# checkpoint saved in an earlier session could later be "evaluated" on
# sequences it was trained on.
split_seed = 0
train_dataset, test_dataset = random_split(
    mm,
    [train_length, test_length],
    generator=torch.Generator().manual_seed(split_seed),
)


In [4]:

# Training batches are reshuffled on every pass over the loader.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
)

# Evaluation batches are served in a fixed order.
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=0,
)

# Target for the moment regularisation in train_on_batch: a (49, 7, 7) tensor
# in which slice k is all zeros except for a single 1 at row k // 7,
# column k % 7. Reshaping a 49x49 identity matrix yields exactly that layout
# in one call instead of 49 element-wise writes.
constraints = torch.eye(49, device=device).reshape(49, 7, 7)


In [5]:

def train_on_batch(input_tensor, target_tensor, encoder, encoder_optimizer, criterion,teacher_forcing_ratio):
    """Run one optimisation step on a single batch and return its scaled loss.

    The loss is the sum of three parts, all measured with `criterion`:
      1. next-frame predictions while the input frames are fed in one by one,
      2. predictions for every target frame during the rollout,
      3. the distance between the moments of the first PhyCell layer's
         `F.conv1` filters and the module-level `constraints` tensor.

    Args:
        input_tensor: observed frames, indexed as [batch, time, ...]
            (torch.Size([batch_size, input_length, channels, cols, rows])).
        target_tensor: frames to predict, indexed the same way.
        encoder: model called as `encoder(frame, is_first)` / `encoder(frame)`
            and returning a 5-tuple whose third item is the predicted frame.
        encoder_optimizer: optimizer that is zeroed and then stepped here.
        criterion: callable `(prediction, target) -> loss tensor`.
        teacher_forcing_ratio: probability that the whole rollout is fed the
            ground-truth frames rather than its own predictions.

    Returns:
        The summed loss as a Python float, divided by the number of target frames.
    """
    encoder_optimizer.zero_grad()
    n_input_frames = input_tensor.size(1)
    n_target_frames = target_tensor.size(1)
    total_loss = 0

    # Phase 1: feed the observed frames; each output is scored against the
    # frame that follows it. The second argument is True only on the very
    # first step (presumably a state reset inside EncoderRNN -- confirm).
    for step in range(n_input_frames - 1):
        _, _, predicted_frame, _, _ = encoder(input_tensor[:, step], step == 0)
        total_loss += criterion(predicted_frame, input_tensor[:, step + 1])

    # Phase 2: roll out the target frames, starting from the last observed one.
    next_input = input_tensor[:, -1]
    # A single coin flip decides teacher forcing for the whole rollout.
    teacher_forced = random.random() < teacher_forcing_ratio
    for step in range(n_target_frames):
        _, _, predicted_frame, _, _ = encoder(next_input)
        true_frame = target_tensor[:, step]
        total_loss += criterion(predicted_frame, true_frame)
        next_input = true_frame if teacher_forced else predicted_frame

    # Phase 3: moment regularisation on the first PhyCell layer.
    # F.conv1.weight has size (nb_filters, in_channels, 7, 7).
    moment_op = K2M([7,7]).to(device)
    first_cell = encoder.phycell.cell_list[0]
    for channel in range(0, first_cell.input_dim):
        kernel_bank = first_cell.F.conv1.weight[:, channel, :, :]  # (nb_filters, 7, 7)
        # K2M is applied in double precision; cast back before comparing.
        moments = moment_op(kernel_bank.double()).float()
        total_loss += criterion(moments, constraints)  # constraints is precomputed

    total_loss.backward()
    encoder_optimizer.step()
    return total_loss.item() / n_target_frames


def trainIters(encoder, nepochs, print_every=1,eval_every=1,name=''):
    """Train `encoder` for `nepochs` epochs and return the per-epoch losses.

    Batches come from the module-level `train_loader`; periodic evaluation
    uses the module-level `test_loader`. After every evaluation the learning
    rate scheduler is stepped with the evaluation MSE and the current weights
    are written to 'PhyDNet/save/encoder_<name>.pth' (overwriting the
    previous checkpoint).

    Args:
        encoder: model to optimise (its parameters are handed to Adam, lr=0.001).
        nepochs: number of passes over `train_loader`.
        print_every: print the epoch loss every this many epochs.
        eval_every: evaluate, step the scheduler and save every this many epochs.
        name: suffix used in the checkpoint file name.

    Returns:
        List with one entry per epoch: the sum of the values returned by
        `train_on_batch` over that epoch's batches.
    """
    train_losses = []

    encoder_optimizer = torch.optim.Adam(encoder.parameters(),lr=0.001)
    scheduler_enc = ReduceLROnPlateau(encoder_optimizer, mode='min', patience=2,factor=0.1,verbose=True)
    criterion = nn.MSELoss()
    for epoch in range(0, nepochs):
        t0 = time.time()
        loss_epoch = 0
        # Teacher forcing probability decays linearly from 1 by 0.003 per
        # epoch and is floored at 0.
        teacher_forcing_ratio = np.maximum(0 , 1 - epoch * 0.003)
        for out in tqdm(train_loader, desc='train'):
            # Items [1] and [2] of a batch are the input and target sequences.
            input_tensor = out[1].to(device)
            target_tensor = out[2].to(device)
            loss = train_on_batch(input_tensor, target_tensor, encoder, encoder_optimizer, criterion, teacher_forcing_ratio)
            loss_epoch += loss

        train_losses.append(loss_epoch)
        if (epoch+1) % print_every == 0:
            print('epoch ',epoch,  ' loss ',loss_epoch, ' time epoch ',time.time()-t0)

        if (epoch+1) % eval_every == 0:
            # Only the MSE drives the scheduler. The other two metrics are
            # deliberately not bound to `mae`/`ssim` so the imported `ssim`
            # function is never shadowed inside this function.
            mse, _mae, _ssim_score = evaluate(encoder,test_loader)
            scheduler_enc.step(mse)
            torch.save(encoder.state_dict(),'PhyDNet/save/encoder_{}.pth'.format(name))
    return train_losses

    
def evaluate(encoder,loader):
    """Roll `encoder` out on every batch of `loader` and report MSE, MAE and SSIM.

    For each batch all input frames except the last are fed through the
    encoder one at a time; the last input frame then seeds a rollout in which
    every predicted frame is fed back in as the next input. Gradients are
    disabled for the whole evaluation.

    Metrics (each averaged over the number of batches):
      * MSE / MAE: error averaged over the batch, time and channel axes and
        then summed over the two remaining (spatial) axes.
      * SSIM: mean of the per-frame `ssim` score over all frames of a batch,
        computed on channel 0 with the module-level `data_range`.

    Args:
        encoder: model called as `encoder(frame, is_first)` and
            `encoder(frame, False, False)`, returning a 5-tuple whose third
            item is the predicted frame. The meaning of the boolean flags is
            defined by EncoderRNN and is not visible here.
        loader: sized iterable of batches; item [1] of a batch is the input
            sequence and item [2] the target sequence, indexed [batch, time, ...].

    Returns:
        Tuple `(mse, mae, ssim)`.

    Raises:
        ValueError: if `loader` yields no batches (the averages are undefined).
    """
    n_batches = len(loader)
    if n_batches == 0:
        raise ValueError('evaluate() needs a loader with at least one batch')

    total_mse, total_mae, total_ssim = 0, 0, 0
    t0 = time.time()
    with torch.no_grad():
        progress = tqdm(loader, desc='eval')
        for n_done, out in enumerate(progress, start=1):

            input_tensor = out[1].to(device)
            target_tensor = out[2].to(device)
            input_length = input_tensor.size()[1]
            target_length = target_tensor.size()[1]

            # Feed the observed frames; only the encoder's internal state is
            # needed from these calls, the outputs are discarded.
            for ei in range(input_length-1):
                encoder(input_tensor[:,ei,:,:,:], (ei==0))

            decoder_input = input_tensor[:,-1,:,:,:] # first decoder input= last image of input sequence
            predictions = []

            # Rollout: every prediction becomes the next input.
            for di in range(target_length):
                _, _, output_image, _, _ = encoder(decoder_input, False, False)
                decoder_input = output_image
                predictions.append(output_image.cpu())

            target = target_tensor.cpu().numpy()
            predictions =  np.stack(predictions) # (10, batch_size, 1, 64, 64)
            predictions = predictions.swapaxes(0,1)  # (batch_size,10, 1, 64, 64)

            mse_batch = np.mean((predictions-target)**2 , axis=(0,1,2)).sum()
            mae_batch = np.mean(np.abs(predictions-target) ,  axis=(0,1,2)).sum()
            total_mse += mse_batch
            total_mae += mae_batch

            # Per-frame SSIM, normalised so each batch contributes its mean score.
            n_frames = target.shape[0]*target.shape[1]
            for a in range(0,target.shape[0]):
                for b in range(0,target.shape[1]):
                    total_ssim += ssim(target[a,b,0,], predictions[a,b,0,], data_range=data_range) / n_frames

            # Refresh the bar with running averages; building the description
            # once before the loop would freeze it at 0.00.
            progress.set_description(f'eval. mse: {total_mse/n_done:.2f}, mae: {total_mae/n_done:.2f}, ssim: {total_ssim/n_done:.2f}')

    print('eval mse ', total_mse/n_batches,  ' eval mae ', total_mae/n_batches,' eval ssim ',total_ssim/n_batches, ' time= ', time.time()-t0)
    return total_mse/n_batches,  total_mae/n_batches, total_ssim/n_batches


# Single-layer PhyCell with 7x7 kernels on a 16x16 grid with 64 input channels.
phycell = PhyCell(
    input_shape=(16,16),
    input_dim=64,
    F_hidden_dims=[49],
    n_layers=1,
    kernel_size=(7,7),
    device=device,
)
# Three-layer ConvLSTM with 3x3 kernels on the same grid and input channels.
convcell = ConvLSTM(
    input_shape=(16,16),
    input_dim=64,
    hidden_dims=[128,128,64],
    n_layers=3,
    kernel_size=(3,3),
    device=device,
)
# The model that is trained and evaluated wraps both cells.
encoder = EncoderRNN(phycell, convcell, device)
  
def count_parameters(model):
    """Return the total number of elements in `model`'s trainable parameters.

    Parameters whose `requires_grad` flag is False are not counted.
    """
    n_trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            n_trainable += param.numel()
    return n_trainable
   
# Trainable-parameter counts for each sub-module and for the full model.
print('phycell ' , count_parameters(phycell) )    
print('convcell ' , count_parameters(convcell) ) 
print('encoder ' , count_parameters(encoder) ) 

# Honour the configured evaluation cadence instead of a hard-coded 1.
trainIters(encoder,nepochs,print_every=print_every,eval_every=eval_every,name=save_name)

# Reload the checkpoint written by trainIters. The path is built from
# `save_name` so it keeps matching if the name is changed; `map_location`
# lets a checkpoint saved on GPU be opened on a CPU-only machine.
encoder.load_state_dict(torch.load('PhyDNet/save/encoder_{}.pth'.format(save_name), map_location=device))
encoder.eval()
# The SSIM score must not be stored in a variable called `ssim`: that would
# overwrite the imported `ssim` function that evaluate() calls, and every
# later evaluate() call in this kernel would fail with a TypeError.
mse, mae, ssim_score = evaluate(encoder,test_loader)


layer  0 input dim  64  hidden dim  128
layer  1 input dim  128  hidden dim  128
layer  2 input dim  128  hidden dim  64
phycell  230803
convcell  2508032
encoder  3091732


train:   0%|          | 0/250 [00:00<?, ?it/s]

train:  22%|██▏       | 56/250 [01:28<05:07,  1.58s/it] 


KeyboardInterrupt: 

In [None]:

# Pull one batch from the evaluation loader and display the shape of its
# target sequence (item [2] of the batch).
first_test_batch = next(iter(test_loader))
first_test_batch[2].shape

torch.Size([32, 10, 1, 64, 64])