In [3]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

In [4]:
class ConvLSTM(nn.Module):
    """Single convolutional LSTM cell.

    One 2-D convolution over the channel-wise concatenation of the input and
    the previous hidden state produces the pre-activations of all four gates
    at once.

    Args:
        shape: (height, width) of the state maps; only used by ``init_hidden``.
        in_channels: number of channels of the input tensor.
        hidden_channels: number of channels of the hidden and cell states.
        kernel_size: size of the square convolution kernel. The padding is
            ``kernel_size // 2``, so the spatial size is preserved only for
            odd kernel sizes.
    """

    def __init__(self, shape, in_channels, hidden_channels, kernel_size=3):
        super(ConvLSTM, self).__init__()

        self.shape = shape
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        # 4 * hidden_channels outputs: input, forget, cell and output gates.
        self.conv = nn.Conv2d(in_channels + hidden_channels, 4 * hidden_channels,
                              kernel_size, padding=kernel_size // 2)

    def forward(self, input, hidden_state):
        """Advance the cell by one time step.

        Args:
            input: tensor concatenated with the hidden state along dim 1, so it
                must match the hidden state in every other dimension.
            hidden_state: pair ``(hx, cx)`` of previous hidden and cell states.

        Returns:
            Pair ``(hy, cy)`` holding the new hidden and cell states.
        """
        hx, cx = hidden_state
        combined = torch.cat((input, hx), 1)
        gates = self.conv(combined)

        # Fix: the original split an undefined name (`gate`), raising NameError.
        ingate, forgetgate, cellgate, outgate = torch.split(gates, self.hidden_channels, dim=1)
        ingate = torch.sigmoid(ingate)
        forgetgate = torch.sigmoid(forgetgate)
        cellgate = torch.tanh(cellgate)
        outgate = torch.sigmoid(outgate)

        cy = (forgetgate * cx) + (ingate * cellgate)
        hy = outgate * torch.tanh(cy)

        return hy, cy

    def init_hidden(self, batch_size):
        """Return zero-filled ``(hidden, cell)`` states.

        Each state has shape ``(batch_size, hidden_channels, shape[0], shape[1])``.
        The tensors are allocated from the convolution weight, so they share
        its device and dtype and stay usable after ``.cuda()`` / ``.half()``.
        """
        state_size = (batch_size, self.hidden_channels, self.shape[0], self.shape[1])
        weight = self.conv.weight
        return (weight.new_zeros(state_size), weight.new_zeros(state_size))
        

In [5]:
class ConvGRU(nn.Module):
    """Single convolutional GRU cell.

    ``conv1`` produces the update and reset gate pre-activations from the
    concatenated input and previous hidden state; ``conv2`` produces the
    candidate state from the input and the reset-gated hidden state.

    Args:
        shape: (height, width) of the state map; only used by ``init_hidden``.
        in_channels: number of channels of the input tensor.
        hidden_channels: number of channels of the hidden state.
        kernel_size: size of the square convolution kernel (defaults to 3,
            matching ``ConvLSTM``). The padding is ``kernel_size // 2``, so the
            spatial size is preserved only for odd kernel sizes.
    """

    def __init__(self, shape, in_channels, hidden_channels, kernel_size=3):
        super(ConvGRU, self).__init__()

        self.shape = shape
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        # 2 * hidden_channels outputs: update gate (z) and reset gate (r).
        self.conv1 = nn.Conv2d(self.in_channels + self.hidden_channels,
                               2 * self.hidden_channels, kernel_size,
                               stride=1, padding=kernel_size // 2)
        # Candidate hidden state.
        self.conv2 = nn.Conv2d(self.in_channels + self.hidden_channels,
                               self.hidden_channels, kernel_size,
                               stride=1, padding=kernel_size // 2)

    def forward(self, input, hidden_state):
        """Advance the cell by one time step.

        Args:
            input: tensor concatenated with the hidden state along dim 1, so it
                must match the hidden state in every other dimension.
            hidden_state: previous hidden state tensor.

        Returns:
            The new hidden state, ``(1 - z) * h_prev + z * candidate``.
        """
        htprev = hidden_state
        combined_1 = torch.cat((input, htprev), 1)
        gates = self.conv1(combined_1)

        # Fix: the original split an undefined name (`gate`), raising NameError.
        zgate, rgate = torch.split(gates, self.hidden_channels, 1)
        z = torch.sigmoid(zgate)
        r = torch.sigmoid(rgate)

        # The reset gate scales the previous state before the candidate conv.
        combined_2 = torch.cat((input, r * htprev), 1)
        ht = torch.tanh(self.conv2(combined_2))
        htnext = (1 - z) * htprev + z * ht

        return htnext

    def init_hidden(self, batch_size):
        """Return a zero-filled hidden state.

        The state has shape ``(batch_size, hidden_channels, shape[0], shape[1])``
        and is allocated from the ``conv1`` weight, so it shares its device and
        dtype and stays usable after ``.cuda()`` / ``.half()``.
        """
        state_size = (batch_size, self.hidden_channels, self.shape[0], self.shape[1])
        return self.conv1.weight.new_zeros(state_size)
    

In [6]:
class CRNN(nn.Module):
    """Stack of convolutional recurrent cells unrolled over a sequence.

    Args:
        shape: (height, width) forwarded to every cell.
        in_channels: number of channels of the input fed to the first layer.
        kernel_size: convolution kernel size forwarded to every cell.
        hidden_channels: number of hidden channels of every layer; layers
            after the first also take this many input channels.
        num_layers: number of stacked cells.
        cell: ``'ConvGRU'`` selects ``ConvGRU`` cells; any other value
            (default ``'ConvLSTM'``) selects ``ConvLSTM`` cells.
    """

    def __init__(self, shape, in_channels, kernel_size, hidden_channels, num_layers, cell='ConvLSTM'):
        super(CRNN, self).__init__()

        self.shape = shape
        self.in_channels = in_channels
        # Fix: kernel_size was read below but never stored (AttributeError).
        self.kernel_size = kernel_size
        self.hidden_channels = hidden_channels
        self.num_layers = num_layers
        self.cell = cell

        cell_cls = ConvGRU if self.cell == 'ConvGRU' else ConvLSTM

        cell_list = []
        for idcell in range(self.num_layers):
            # Only the first layer sees the raw input; deeper layers consume
            # the hidden-state sequence of the layer below.
            layer_in = self.in_channels if idcell == 0 else self.hidden_channels
            # Fix: keyword arguments, because the cell signature is
            # (shape, in_channels, hidden_channels, kernel_size) and the
            # original positional call passed kernel_size as hidden_channels.
            cell_list.append(cell_cls(self.shape,
                                      in_channels=layer_in,
                                      hidden_channels=self.hidden_channels,
                                      kernel_size=self.kernel_size))
        self.cell_list = nn.ModuleList(cell_list)

    def forward(self, input, hidden_state):
        """Run every layer over the whole sequence.

        Args:
            input: sequence tensor whose first dimension is time; each
                ``input[t]`` is handed to the first cell.
            hidden_state: one initial state per layer, as returned by
                ``init_hidden``.

        Returns:
            Pair ``(next_hidden, output)``: ``next_hidden`` is the list of
            final per-layer states, ``output`` is the hidden-state sequence of
            the last layer stacked along dim 0.
        """
        current_input = input
        next_hidden = []
        seq_len = current_input.size(0)
        # Same rule as __init__: everything except 'ConvGRU' is a ConvLSTM,
        # whose state is a (hidden, cell) pair.
        state_is_pair = self.cell != 'ConvGRU'

        for idlayer in range(self.num_layers):
            hidden_c = hidden_state[idlayer]
            output_inner = []

            for t in range(seq_len):
                # Fix: the cell needs the running state as its second argument.
                hidden_c = self.cell_list[idlayer](current_input[t], hidden_c)
                output_inner.append(hidden_c[0] if state_is_pair else hidden_c)

            next_hidden.append(hidden_c)
            # Equivalent to the original cat(...).view(seq_len, ...).
            current_input = torch.stack(output_inner, 0)

        # Fix: the original computed the results but returned nothing.
        return next_hidden, current_input

    def init_hidden(self, batch_size):
        """Return the list of initial states, one per layer."""
        return [self.cell_list[i].init_hidden(batch_size) for i in range(self.num_layers)]