<a href="https://colab.research.google.com/github/AjayKumarGogineni777/ConvLSTM/blob/master/Conv_LSTM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [0]:

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
from skimage import io
import os
import scipy.ndimage as ndimage
import matplotlib.pyplot as plt
import scipy.misc
import cv2

In [0]:
from google.colab import drive
drive.mount('/content/gdrive')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3Aietf%3Awg%3Aoauth%3A2.0%3Aoob&scope=email%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdocs.test%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive.photos.readonly%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fpeopleapi.readonly&response_type=code

Enter your authorization code:
··········
Mounted at /content/gdrive


In [0]:
import torch
import torch.nn as nn
from torch.autograd import Variable


class ConvLSTMCell(nn.Module):
    def __init__(self, input_channels, hidden_channels, kernel_size):
        super(ConvLSTMCell, self).__init__()

        #assert hidden_channels % 2 == 0

        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.num_features = 4

        self.padding = int((kernel_size - 1) / 2)

        self.Wxi = nn.Conv2d(self.input_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=True)
        self.Whi = nn.Conv2d(self.hidden_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=False)
        self.Wxf = nn.Conv2d(self.input_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=True)
        self.Whf = nn.Conv2d(self.hidden_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=False)
        self.Wxc = nn.Conv2d(self.input_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=True)
        self.Whc = nn.Conv2d(self.hidden_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=False)
        self.Wxo = nn.Conv2d(self.input_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=True)
        self.Who = nn.Conv2d(self.hidden_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=False)

        self.Wci = None
        self.Wcf = None
        self.Wco = None

    def forward(self, x, h, c):
        ci = torch.sigmoid(self.Wxi(x) + self.Whi(h) + c * self.Wci)
        cf = torch.sigmoid(self.Wxf(x) + self.Whf(h) + c * self.Wcf)
        cc = cf * c + ci * torch.tanh(self.Wxc(x) + self.Whc(h))
        co = torch.sigmoid(self.Wxo(x) + self.Who(h) + cc * self.Wco)
        ch = co * torch.tanh(cc)
        return ch, cc

    def init_hidden(self, batch_size, hidden, shape):
        if self.Wci is None:
            self.Wci = Variable(torch.zeros(1, hidden, shape[0], shape[1]))## cuda removed
            self.Wcf = Variable(torch.zeros(1, hidden, shape[0], shape[1]))## cuda removed
            self.Wco = Variable(torch.zeros(1, hidden, shape[0], shape[1]))## cuda removed
        else:
            assert shape[0] == self.Wci.size()[2], 'Input Height Mismatched!'
            assert shape[1] == self.Wci.size()[3], 'Input Width Mismatched!'
        return (Variable(torch.zeros(batch_size, hidden, shape[0], shape[1])),## cuda removed
                Variable(torch.zeros(batch_size, hidden, shape[0], shape[1]))) ## cuda removed


class ConvLSTM(nn.Module):
    # input_channels corresponds to the first input feature map
    # hidden state is a list of succeeding lstm layers.
    def __init__(self, input_channels, hidden_channels, kernel_size, step=1, effective_step=[1]):
        super(ConvLSTM, self).__init__()
        self.input_channels = [input_channels] + hidden_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.num_layers = len(hidden_channels)
        self.step = step
        self.effective_step = effective_step
        self._all_layers = []
        for i in range(self.num_layers):
            name = 'cell{}'.format(i)
            cell = ConvLSTMCell(self.input_channels[i], self.hidden_channels[i], self.kernel_size)
            setattr(self, name, cell)
            self._all_layers.append(cell)
            #self.fc = nn.Linear(in_features = self.hidden_channels[i], out_features = 3) 
        self.fc = nn.Linear(in_features = 32, out_features = 3)    
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, input):
        internal_state = []
        outputs = []
        
        final_outs = []
        for step in range(self.step):
            x = input
            for i in range(self.num_layers):
                # all cells are initialized in the first step
                name = 'cell{}'.format(i)
                if step == 0:
                    bsize, _, height, width = x.size()
                    (h, c) = getattr(self, name).init_hidden(batch_size=bsize, hidden=self.hidden_channels[i],
                                                             shape=(height, width))
                    internal_state.append((h, c))

                # do forward
                (h, c) = internal_state[i]
                x, new_c = getattr(self, name)(x, h, c)
                internal_state[i] = (x, new_c)
                foutput = self.fc(x.contiguous().view(-1, 32))
                #foutput = self.relu(foutput)
                
                foutput = self.sigmoid(foutput)-0.5
                #print(foutput.shape)
                #print(x.shape)
                foutput = foutput.reshape((5, 3, 128, 128))
                
            # only record effective steps
            if step in self.effective_step:
                outputs.append(x)
                final_outs.append(foutput)

        return final_outs, outputs, (x, new_c)




In [0]:
### Custom Cross Entropy for LSTM outputs
#import torch.log as log

import torch.nn.functional as F
def lstm_loss_mod(pred,target):
  X = 128
  Y = 128
  Z = 3
  loss = 0
  eps = 0.00001
  bs,ch,height,width = pred.shape
  for b in range(bs):
    for c in range(ch):
      for h in range(height):
        for w in range(width):
          P = pred[b,c,h,w]
          T = target[b,c,h,w]
          a1 = torch.log(torch.abs(P+eps))
          a2 = torch.log(torch.abs(1-P+eps))
          #loss += T*a1 + (1-T)*a2
          #loss += T*P + (1-T)*(1-P)
          #loss += torch.norm(T - P, 2)
          
          b1 = torch.sigmoid(P)
          #b2 = torch.sigmoid(T)
          loss += T*torch.log(b1) + (1-T)*torch.log(1 - b1)
          
  return -1*loss




##### Complete training loop

lr = 1e-4
T = 5
    
    
convlstm = ConvLSTM(input_channels=3, hidden_channels=[32, 32], kernel_size=5, step=5,
                        effective_step=[4])

opt = torch.optim.RMSprop(convlstm.parameters(), lr=lr, weight_decay=1e-5) 

#### Loading previously saved model
convlstm.load_state_dict(torch.load('/content/gdrive/My Drive/ConvLSTM/models/convlstm_30_epochs'))

epochs = 6

import pickle
with open('/content/gdrive/My Drive/ConvLSTM/train_list_x/tr_x_pkl_all', "rb") as fp:
  b = pickle.load(fp)
#with open('/content/gdrive/My Drive/ConvLSTM/train_list_y/tr_y_pkl_all', "rb") as fc:
  #c = pickle.load(fc)

for epoch in range(epochs):
  for i in range(10):
    x_complete = b
    #y_complete = c
    x = x_complete[i]


##### Importnat to check

    #y = y_complete[i]


    y = x_complete[i+1]


    input = (Variable(x).float()/255)-0.5
    target = (y.double()/255)-0.5

    output = convlstm(input)


    output = output[0][0].double()
    loss = lstm_loss_mod(output,target)
#res = torch.autograd.gradcheck(loss_fn, (output, target), eps=1e-6, raise_exception=True)
#print(res)
    opt.zero_grad()
    loss.backward()
    opt.step()
  print(loss)
  print(' > Epoch {:2d} loss: {:.3f}'.format((epoch+1), loss.item()))
  

tensor(165003.1011, dtype=torch.float64, grad_fn=<MulBackward0>)
 > Epoch  1 loss: 165003.101
tensor(164889.2140, dtype=torch.float64, grad_fn=<MulBackward0>)
 > Epoch  2 loss: 164889.214
tensor(164857.5150, dtype=torch.float64, grad_fn=<MulBackward0>)
 > Epoch  3 loss: 164857.515


In [0]:
#### Saving the model


st = convlstm.state_dict()
mdl_path = '/content/gdrive/My Drive/ConvLSTM/models/'
#torch.save(st, mdl_path+'/50_items_adam_24_epochs')
torch.save(st, mdl_path+'/convlstm_36_epochs')

In [0]:
### x_complete is input with first element as a series of images from index 0 to index 5

### y_complete is output with first element as a series of images from index 1 to index 6, i.e. shifted by one index w.r.t x_complete

x_complete = []
y_complete = []

#N = 40
#s = 8

N = 309
s = 13
T = 5
for num in range(N):
  for t in range(s):
    x = []
    y = []
    for i in range(t, t+T, 1):
      i = i+1
      ima = io.imread('/content/gdrive/My Drive/Satellite/ConvLSTM/train/%d/%d.tiff'%(i,num))
      ima = np.rollaxis(ima, -1, 0)
      imb = io.imread('/content/gdrive/My Drive/Satellite/ConvLSTM/train/%d/%d.tiff'%(i+1, num))
      imb = np.rollaxis(imb, -1, 0)
      x.append(ima)
      y.append(imb)
     
    x = torch.from_numpy(np.array(x)).float()
    y = torch.from_numpy(np.array(y)).float()
    x_complete.append(x)
    y_complete.append(y)

In [0]:
##### Creation of Data and storing the data in pickle files
## x_complete is input and y_complete is output

        
import pickle

with open('/content/gdrive/My Drive/Satellite/ConvLSTM/train_list_x/tr_x_pkl_all', 'wb') as fp:
    pickle.dump(x_complete, fp)   
    

with open('/content/gdrive/My Drive/Satellite/ConvLSTM/train_list_y/tr_y_pkl_all', 'wb') as fy:
    pickle.dump(y_complete, fy)      
        

In [0]:
### output after few epochs
output

tensor([[[[-0.0070, -0.0005, -0.0145,  ..., -0.0226, -0.0130, -0.0040],
          [-0.0226, -0.0127, -0.0046,  ..., -0.0040, -0.0228, -0.0130],
          [-0.0040, -0.0226, -0.0130,  ..., -0.0127, -0.0046, -0.0227],
          ...,
          [-0.3661, -0.4300, -0.3579,  ..., -0.3613, -0.3600, -0.4309],
          [-0.3049, -0.3052, -0.3727,  ..., -0.4300, -0.3579, -0.3661],
          [-0.4300, -0.3760, -0.3126,  ..., -0.3052, -0.3727, -0.3579]],

         [[-0.3661, -0.4300, -0.2010,  ...,  0.0180,  0.0272,  0.0211],
          [-0.3049, -0.3052, -0.3727,  ..., -0.4305,  0.0227,  0.0323],
          [ 0.0289,  0.0136,  0.0233,  ..., -0.3052, -0.3727, -0.3632],
          ...,
          [-0.4445, -0.3802, -0.3814,  ..., -0.3847, -0.4432, -0.3776],
          [-0.3859, -0.4445, -0.3776,  ..., -0.3802, -0.3814, -0.4441],
          [-0.3756, -0.3847, -0.4432,  ..., -0.4445, -0.3776, -0.3859]],

         [[-0.4445, -0.3802, -0.3814,  ..., -0.3847, -0.4432, -0.3776],
          [-0.3859, -0.4445, -

In [0]:
## The target tensor
target

tensor([[[[ 0.5000,  0.5000,  0.5000,  ...,  0.5000,  0.5000,  0.5000],
          [ 0.5000,  0.5000,  0.5000,  ...,  0.5000,  0.5000,  0.5000],
          [ 0.5000,  0.5000,  0.5000,  ...,  0.5000,  0.5000,  0.5000],
          ...,
          [ 0.5000,  0.5000,  0.5000,  ...,  0.3627,  0.3392,  0.3627],
          [ 0.5000,  0.5000,  0.5000,  ...,  0.3431,  0.3392,  0.3784],
          [ 0.5000,  0.5000,  0.5000,  ...,  0.3314,  0.3196,  0.3745]],

         [[ 0.5000,  0.5000,  0.5000,  ...,  0.5000,  0.5000,  0.5000],
          [ 0.5000,  0.5000,  0.5000,  ...,  0.5000,  0.5000,  0.5000],
          [ 0.5000,  0.5000,  0.5000,  ...,  0.5000,  0.5000,  0.5000],
          ...,
          [ 0.5000,  0.5000,  0.5000,  ..., -0.1667, -0.1510, -0.2255],
          [ 0.5000,  0.5000,  0.5000,  ..., -0.1627, -0.1627, -0.2059],
          [ 0.5000,  0.5000,  0.5000,  ..., -0.1588, -0.1824, -0.1941]],

         [[ 0.5000,  0.5000,  0.5000,  ...,  0.5000,  0.5000,  0.5000],
          [ 0.5000,  0.5000,  