In [5]:
import glob
import math
import os
import re
from typing import Any, Callable, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
from omegaconf import OmegaConf
from torch import Tensor
from torch.utils.data import Dataset, random_split, DataLoader

In [6]:
# Define Model

class Encoder(nn.Module):
    """Convolutional image encoder that halves the spatial resolution.

    Three 3x3 convolutions (ReLU + BatchNorm after the first two, sigmoid
    after the last). The last convolution has stride 2, so an input of
    ``[b, input_size, h, w]`` becomes ``[b, output_size, ceil(h/2), ceil(w/2)]``.

    Args:
        input_size: channels of the input images.
        hidden_size: channels of the intermediate feature maps.
        output_size: channels of the output; defaults to ``input_size``.
        name: unused; kept for interface compatibility.
    """

    def __init__(self, input_size, hidden_size, output_size=None, name='Convolutional Encoder'):
        super(Encoder, self).__init__()

        self.id = 'encoder'

        if output_size is None:
            output_size = input_size

        # Layer order is kept stable so state_dict keys (encoder.0, encoder.2, ...)
        # stay compatible with existing checkpoints.
        self.encoder = nn.Sequential(
            nn.Conv2d(input_size, hidden_size, 3, padding=(1, 1)),
            nn.ReLU(),
            nn.BatchNorm2d(hidden_size),

            nn.Conv2d(hidden_size, hidden_size, 3, padding=(1, 1), stride=(1, 1)),
            nn.ReLU(),
            nn.BatchNorm2d(hidden_size),

            # stride 2 halves height and width (h -> ceil(h / 2))
            nn.Conv2d(hidden_size, output_size, 3, padding=(1, 1), stride=(2, 2)),
            nn.Sigmoid()
        )

    def forward(self, input):
        """Encode a batch of images ``[b, input_size, h, w]``."""
        return self.encoder(input)
    
class Decoder(nn.Module):
    """Convolutional image decoder that doubles the spatial resolution.

    ``nn.Upsample(scale_factor=2)`` followed by three 3x3 convolutions
    (ReLU + BatchNorm after the first two, sigmoid after the last). An input of
    ``[b, input_size, h, w]`` becomes ``[b, output_size, 2h, 2w]``.

    Args:
        input_size: channels of the input feature maps.
        hidden_size: channels of the intermediate feature maps.
        output_size: channels of the output; defaults to ``input_size``.
        name: unused; kept for interface compatibility.
    """

    def __init__(self, input_size, hidden_size, output_size=None, name='Convolutional Decoder'):
        super(Decoder, self).__init__()

        self.id = 'decoder'

        if output_size is None:
            output_size = input_size

        # Layer order is kept stable so state_dict keys stay checkpoint-compatible.
        self.decoder = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(input_size, hidden_size, 3, padding=(1, 1)),
            nn.ReLU(),
            nn.BatchNorm2d(hidden_size),

            nn.Conv2d(hidden_size, hidden_size, 3, padding=(1, 1)),
            nn.ReLU(),
            nn.BatchNorm2d(hidden_size),

            nn.Conv2d(hidden_size, output_size, 3, padding=(1, 1)),
            nn.Sigmoid()
        )

    def forward(self, input):
        """Decode a batch of feature maps ``[b, input_size, h, w]``."""
        return self.decoder(input)
 
class TimeDistributed(nn.Module):
    """Apply ``module`` independently to every element of a sequence.

    The first two dimensions of the input (batch and time, in either order)
    are folded into a single batch dimension before ``module`` is called and
    unfolded again afterwards. The trailing dimensions of the result are
    whatever ``module`` produces.

    Args:
        module: module applied to each ``[N, ...]`` slice.
        name: unused; kept for interface compatibility.
    """

    def __init__(self, module, name='Time Distributed'):
        super().__init__()
        self.module = module

    def forward(self, input):
        leading = tuple(input.shape[:2])
        folded = input.reshape((leading[0] * leading[1],) + tuple(input.shape[2:]))
        result = self.module(folded)
        return result.reshape(leading + tuple(result.shape[1:]))

class CLSTM_cell(nn.Module):
    """Single convolutional LSTM (ConvLSTM) cell.

    One ``Conv2d`` over the channel-wise concatenation ``[input, H]`` produces
    the pre-activations of all four gates (``4 * hidden_size`` channels),
    split in the order forget, input, candidate, output. The update is the
    standard LSTM recurrence with convolutions:

        f, i, o = sigmoid(.)      g = tanh(.)
        C' = f * C + i * g
        H' = o * tanh(C')

    Args:
        input_size: number of channels of ``input``.
        hidden_size: number of channels of ``H`` and ``C``.
        kernel_size: kernel size of the gate convolution.
        padding: padding of the gate convolution (the default ``(1, 1)`` with
            a 3x3 kernel keeps the spatial size unchanged).
        name: unused; kept for interface compatibility.
    """

    def __init__(self, input_size, hidden_size, kernel_size=(3,3), padding=(1,1), name='CLSTM-Cell'):
        super(CLSTM_cell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        self.conv = nn.Conv2d(input_size + hidden_size, hidden_size * 4, kernel_size, 1, padding)
        self.reset_parameters()

    def _zero_state(self, input: Tensor) -> Tensor:
        """Zero state ``[b, hidden_size, h, w]`` matching ``input``'s batch, spatial size, dtype and device."""
        return torch.zeros(input.size(0), self.hidden_size, input.size(2), input.size(3),
                           dtype=input.dtype, device=input.device)

    def forward(self, input, H=None, C=None):
        """Advance the cell by one time step.

        Args:
            input: tensor ``[b, input_size, h, w]``.
            H: previous hidden state ``[b, hidden_size, h, w]``; zeros if None.
            C: previous cell state ``[b, hidden_size, h, w]``; zeros if None.

        Returns:
            ``(H, C)``: the new hidden and cell states.

        Raises:
            RuntimeError: if ``input.size(1) != input_size``.
        """
        self.check_forward_input(input)
        if H is None:
            H = self._zero_state(input)
        if C is None:
            C = self._zero_state(input)

        # [b, input_size + hidden_size, h, w] -> [b, 4 * hidden_size, h, w]
        gates = self.conv(torch.cat([input, H], dim=1))
        _f, _i, _c, _o = torch.split(gates, self.hidden_size, dim=1)

        f = torch.sigmoid(_f)  # forget gate
        i = torch.sigmoid(_i)  # input gate
        g = torch.tanh(_c)     # candidate cell content
        o = torch.sigmoid(_o)  # output gate

        C = f * C + i * g
        # The hidden state is read out of the *updated* cell state; using the
        # candidate here would cut the memory path out of H entirely.
        H = o * torch.tanh(C)
        return H, C

    def reset_parameters(self) -> None:
        """Initialise every parameter uniformly in ``[-1/sqrt(hidden_size), 1/sqrt(hidden_size)]``."""
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            nn.init.uniform_(weight, -stdv, stdv)

    def check_forward_input(self, input: Tensor) -> None:
        """Raise ``RuntimeError`` if ``input``'s channel dimension is not ``input_size``."""
        if input.size(1) != self.input_size:
            raise RuntimeError(
                "input has inconsistent input_size: got {}, expected {}".format(
                    input.size(1), self.input_size))

class CLSTM_layers(nn.Module):
    """Stack of ``num_layers`` ConvLSTM cells run over a sequence.

    The first cell maps ``input_size`` channels to ``hidden_size``; every
    following cell maps ``hidden_size`` to ``hidden_size``. Layers are run
    one after another: layer 0 over all time steps, then layer 1 over
    layer 0's hidden states, and so on.

    Args:
        input_size: channels of each input frame.
        hidden_size: channels of every cell's hidden and cell state.
        num_layers: number of stacked cells.
        name: unused.
    """
    def __init__(self, input_size, hidden_size, num_layers=3, name='CLSTM-Layers'):
        super(CLSTM_layers, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size

        cell_0 = [CLSTM_cell(input_size, hidden_size)]
        self.cells = nn.ModuleList(cell_0 + [CLSTM_cell(hidden_size, hidden_size) for i in range(1,num_layers)])

    def getShapedTensor(self, input, repeat_input=False):
        """Return a zero state ``[b, hidden_size, h, w]`` matching ``input``.

        With ``repeat_input`` the input is a single frame ``[b, c, h, w]``
        (spatial size from dims 2 and 3). Otherwise it is a sequence
        ``[b, t, c, h, w]`` (spatial size from dims 3 and 4).
        """
        if repeat_input:
            return torch.zeros(input.size(0), self.hidden_size, input.size(2), input.size(3), dtype=input.dtype, device=input.device)
        else:
            return torch.zeros(input.size(0), self.hidden_size, input.size(3), input.size(4), dtype=input.dtype, device=input.device)

    def forward(self, input, H=None, C=None, return_sequence=False, repeat_input=False, output_sequence_length=8):
        """Run the stack over a sequence.

        Args:
            input: if ``repeat_input`` is False, a sequence ``[b, t, c, h, w]``
                whose frames ``input[:, t]`` are fed to layer 0 in order. If
                True, a single frame ``[b, c, h, w]`` fed to layer 0 at each of
                ``output_sequence_length`` steps.
            H: optional initial hidden state for layer 0 (zeros if None).
            C: optional initial cell state for layer 0 (zeros if None).
            return_sequence: if True, return the last layer's hidden state at
                every step, stacked as ``[b, T, hidden_size, h, w]``.
                Otherwise return the last layer's final ``(H, C)``.
            repeat_input: selects between the two input modes above.
            output_sequence_length: number of steps when ``repeat_input``.

        NOTE(review): layers >= 1 restart from a zero cell state, but their
        initial hidden state is the ``H`` left over from the last time step
        of the previous layer (``H`` is never reset) — confirm this is intended.
        """
        # if return sequence: return all h's,
        # else: returns last H and C

        # One slot per time step; each slot ends up holding the current layer's H for that step.
        if repeat_input:
            Hs = [self.getShapedTensor(input, True) for _ in range(output_sequence_length)]
        else:
            Hs = [self.getShapedTensor(input) for _ in range(input.size(1))]

        # NOTE(review): Hs[0] is overwritten at t == 0 in the first loop below,
        # so this assignment has no effect on the result; the caller's H still
        # reaches layer 0 through the `H` variable itself.
        if not isinstance(H, type(None)):
            Hs[0] = H
        if isinstance(C, type(None)):
            C = self.getShapedTensor(input, repeat_input)

        if repeat_input:
            # layer 0: the same frame is fed at every step
            for t in range(output_sequence_length):
                H, C = self.cells[0](input, H=H, C=C)
                Hs[t] = H

            # layers 1..: consume the previous layer's hidden states in place
            for l in range(1, len(self.cells)):
                for t in range(output_sequence_length):
                    if t==0:
                        # fresh cell state per layer (H is carried over, see NOTE above)
                        C = self.getShapedTensor(input, repeat_input)
                    H, C = self.cells[l](Hs[t], H=H, C=C)
                    Hs[t] = H
        else:
            for t in range(input.size(1)): # number of time steps
                H, C = self.cells[0](input[:,t], H=H, C=C)
                Hs[t] = H

            # layers 1..: consume the previous layer's hidden states in place
            for l in range(1, len(self.cells)):
                for t in range(input.size(1)):
                    if t==0:
                        # fresh cell state per layer (H is carried over, see NOTE above)
                        C = self.getShapedTensor(input, repeat_input)
                    H, C = self.cells[l](Hs[t], H=H, C=C)
                    Hs[t] = H

        if return_sequence:
            return torch.stack(Hs).transpose(0,1) # [b,t,hd,:,:]
        return Hs[-1], C# last layer's final H [b,hd,:,:] and C [b,hd,:,:]

class CLSTM(nn.Module):
    """Convolutional encoder -> ConvLSTM -> decoder sequence model.

    For an input of ``[b, t, input_size, h, w]``:

    1. ``encoder``: a per-frame ``Encoder`` producing
       ``[b, t, hidden_size, ceil(h/2), ceil(w/2)]``.
    2. ``encoder_clstm``: a 3-layer ConvLSTM run over the time-reversed
       sequence; only its final ``(H, C)`` is kept.
    3. ``decoder_clstm``: a ConvLSTM stack that is fed that ``H`` at every
       step for ``max_depth`` steps, producing
       ``[b, max_depth, hidden_size, ceil(h/2), ceil(w/2)]``.
    4. ``decoder``: a per-frame ``Decoder`` that upsamples by 2, producing
       ``[b, max_depth, output_size, 2*ceil(h/2), 2*ceil(w/2)]``.

    The result is returned with dims 1 and 2 swapped:
    ``[b, output_size, max_depth, H_out, W_out]``.

    Args:
        input_size: channels of each input frame.
        hidden_size: channels of the encoder features and ConvLSTM states.
        output_size: channels of each output frame; defaults to ``input_size``.
        directional, name: unused; kept for interface compatibility.
    """

    def __init__(self, input_size, hidden_size, output_size=None, directional=1, name='CLSTM'):
        super(CLSTM, self).__init__()

        self.id = 'CLSTM'

        # The ConvLSTM state width is tied to the encoder's feature width.
        clstm_hidden_size = hidden_size

        if output_size is None:
            output_size = input_size

        self.encoder = TimeDistributed(Encoder(input_size, hidden_size, output_size=hidden_size))
        self.decoder = TimeDistributed(Decoder(clstm_hidden_size, hidden_size, output_size=output_size))

        self.encoder_clstm = CLSTM_layers(hidden_size, clstm_hidden_size, num_layers=3)
        self.decoder_clstm = CLSTM_layers(clstm_hidden_size, clstm_hidden_size)

    def forward(self, input, max_depth=10):
        """Map a sequence ``[b, t, input_size, h, w]`` to ``max_depth`` output frames.

        Returns:
            tensor ``[b, output_size, max_depth, H_out, W_out]``.
        """
        features = self.encoder(input)
        # The encoder ConvLSTM reads the sequence in reverse time order.
        H, C = self.encoder_clstm(features.flip(1), return_sequence=False)
        decoded = self.decoder_clstm(H, H=H, C=C, return_sequence=True,
                                     repeat_input=True, output_sequence_length=max_depth)
        frames = self.decoder(decoded)
        return frames.transpose(1, 2)


In [7]:
# Define Dataset

class HeartStaticDataset(Dataset):
    """Dataset of heart-simulation snapshots stored as raw float64 files.

    One item per snapshot file in ``config.data_folder``. Depending on
    ``config.mode`` (read on every access), an item is:

    * ``'load_file'``: the flat array read from disk.
    * ``'surface'``: an ``(N, 4)`` array of rows ``[i, j, k, value]``, one per
      non-zero snapshot voxel that lies on the outer surface.
    * ``'projection'``: the ``(N, 3)`` surface coordinates projected radially
      onto a sphere of radius ``PROJECTION_RADIUS`` around their centroid.

    Any other mode returns the ``'surface'`` output.

    Args:
        config: object exposing ``data_folder``, ``kind`` (``'V'`` selects
            ``V_snap*`` files, anything else selects ``C_snap*``), ``mode``,
            ``shape`` (3-D voxel grid shape), ``surface_file`` (raw int32
            file) and ``outer_surface_file`` (loadable with ``np.load``).
        sim_name: unused.
        kind: unused; file selection reads ``config.kind``.
        mode: only stored as ``self.mode``; output selection reads
            ``config.mode``.
    """

    # Radius of the sphere used by the 'projection' mode.
    PROJECTION_RADIUS = 300

    def __init__(self, config, sim_name='Scroll1', kind='V', mode='projection'):
        self.config = config
        self.mode = mode  # NOTE(review): not used by __getitem__, which reads config.mode

        pattern = 'V_snap*' if config.kind == 'V' else 'C_snap*'
        self.files = sorted(glob.glob(os.path.join(config.data_folder, pattern)),
                            key=self._natural_key)

        self.surface = np.fromfile(self.config.surface_file, dtype=np.int32)

        self.outer_surface = np.load(self.config.outer_surface_file)
        self.outer_surface_inds = self.List2Vec(self.outer_surface)
        # Flat 0/1 mask of the outer surface; identical for every item, so built once.
        self._outer_surface_mask = self.Vec2List(self.outer_surface_inds)

    @staticmethod
    def _natural_key(text):
        """Sort key that orders embedded integers numerically ('snap2' < 'snap10')."""
        return [int(part) if part.isdigit() else part for part in re.split(r'(\d+)', text)]

    def __len__(self):
        # One item per snapshot file found on disk.
        return len(self.files)

    def __getitem__(self, idx):
        # 1: raw flat snapshot
        data = np.fromfile(self.files[idx], dtype=np.double)
        if self.config.mode == 'load_file':
            return data

        # 2: keep only outer-surface voxels -> rows [i, j, k, value]
        dense_pc = self.List2Vec(data * self._outer_surface_mask, return_values=True)
        if self.config.mode != 'projection':
            # 'surface' and any unrecognised mode
            return dense_pc

        # 3: radial projection onto a sphere centred on the points' centroid
        centre = np.mean(dense_pc[:, :3], axis=0)
        return self.get_spherical_projection(dense_pc[:, :3], *centre, self.PROJECTION_RADIUS)

    def List2Vec(self, lst, return_values=False):
        """Return the grid coordinates of the non-zero entries of a flat voxel array.

        Args:
            lst: array reshapeable to ``config.shape``.
            return_values: if True, append each voxel's value as a 4th column.

        Returns:
            ``(N, 3)`` int16 coordinates, or ``(N, 4)`` ``[i, j, k, value]``
            rows when ``return_values`` is True.
        """
        vecs = np.reshape(lst, self.config.shape)
        nonzero = np.nonzero(vecs)
        pos = np.column_stack(nonzero).astype(np.int16)

        if not return_values:
            return pos

        # Look values up with the original (uncast) indices.
        values = vecs[nonzero][:, None]
        return np.concatenate([pos, values], axis=1)

    def Vec2List(self, vecs):
        """Inverse of ``List2Vec``: flat 0/1 array with ones at the given ``(N, 3)`` coordinates."""
        cube = np.zeros(self.config.shape)
        cube[vecs[:, 0], vecs[:, 1], vecs[:, 2]] = 1
        return cube.reshape(-1)

    def get_spherical_projection(self, coords, x0, y0, z0, r1):
        """Project points radially onto the sphere of radius ``r1`` centred at ``(x0, y0, z0)``.

        Args:
            coords: ``(N, 3)`` point coordinates.
            x0, y0, z0: sphere centre.
            r1: sphere radius.

        Returns:
            ``(N, 3)`` projected coordinates. A point that coincides with the
            centre has no direction and yields NaN.
        """
        centre = np.array([x0, y0, z0], dtype=float)
        offsets = np.asarray(coords, dtype=float) - centre
        r2 = np.sqrt(np.sum(offsets ** 2, axis=1, keepdims=True))
        return centre + (r1 / r2) * offsets

class HeartDataModule(pl.LightningDataModule):
    """Lightning data module that builds one dataset and splits it ~95 % / 5 %.

    There is no separate test split; the test loader reuses the validation split.

    Args:
        config: object exposing ``dataset_dir``, ``time_step``, ``depth`` and
            ``batch_size``.
    """

    # Fraction of the dataset assigned to training; the remainder is validation.
    TRAIN_FRACTION = 0.95
    # Worker processes used by every DataLoader.
    NUM_WORKERS = 4

    def __init__(self, config):
        super().__init__()
        self.config = config

    def prepare_data(self):
        """Nothing to download or preprocess."""
        pass

    @staticmethod
    def _split_sizes(n, train_fraction=0.95):
        """Return ``(n_train, n_val)`` with ``n_train`` rounded and ``n_train + n_val == n``.

        Rounding both parts independently can over-allocate by one (e.g.
        ``n=10`` gives 10 + 1), which makes ``random_split`` raise.
        """
        n_train = int(n * train_fraction + 0.5)
        return n_train, n - n_train

    def setup(self, stage=None):
        """Build the dataset and split it into ``train_dataset`` / ``val_dataset``."""
        # NOTE(review): BarkleyDataset is neither defined nor imported in this
        # section, while HeartStaticDataset is defined just above — confirm
        # which dataset this module is meant to load.
        dataset = BarkleyDataset(self.config.dataset_dir,
                                 time_steps=self.config.time_step,
                                 depth=self.config.depth)

        n_train, n_val = self._split_sizes(len(dataset), self.TRAIN_FRACTION)
        self.train_dataset, self.val_dataset = random_split(dataset, [n_train, n_val])

    def _loader(self, dataset, shuffle):
        """DataLoader over ``dataset`` with the configured batch size."""
        return DataLoader(dataset, batch_size=self.config.batch_size,
                          num_workers=self.NUM_WORKERS, shuffle=shuffle)

    def train_dataloader(self):
        return self._loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        return self._loader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        # No dedicated test split exists; evaluate on the validation split.
        return self._loader(self.val_dataset, shuffle=False)

In [None]:
# TODO: build the dataset once a config object is available, e.g.
#   dataset = HeartStaticDataset(config)
# (the original incomplete `dataset = ` line was a SyntaxError that broke Run All)
dataset = None

In [None]:
# Define Training