---
title: "Understanding ConvLSTM"
author: "Kirtan Gangani"
date: "July 13, 2025" 
categories: [CNN, LSTM]
format:
  html:
    toc: true
    code-fold: false
    code-copy: true
jupyter: python3
image: "convlstm.png"
---

# Introduction

In the world of deep learning, we've witnessed incredible breakthroughs. **Convolutional Neural Networks (CNNs)** have revolutionized the way we interpret **images and spatial data**, while **Recurrent Neural Networks (RNNs)**, particularly **Long Short-Term Memory (LSTM) networks**, have become indispensable for understanding **sequential information** like text and time series.

But what happens when your data is both spatial and sequential? Think about video: it's a sequence of images. How do you analyze dynamic patterns that evolve in both space and time? This is where traditional networks often struggle, and it's precisely the challenge that ConvLSTM was designed to address.

In this blog post, we will revisit the fundamental building blocks, examine their individual limitations, and then explore the solution that ConvLSTM offers for modeling dynamic data.

# The Building Blocks: CNNs and LSTMs

Let's first quickly revisit ConvLSTM's two powerful parents.

## CNNs

CNNs are a specialized type of neural network designed for **processing grid-like data**. This makes them exceptionally good at handling images, video frames, and other spatial datasets. Their power comes from their ability to automatically **learn** hierarchical spatial features **directly from raw input** (e.g., pixels). This happens in two stages:

1. Early layers detect **simple local patterns** such as edges, corners, and textures.
2. Deeper layers then combine these to **recognize complex objects and patterns**.

![](cnn.png)

: Figure 1: A conceptual diagram of a Convolutional Neural Network, showing how spatial features are extracted and downsampled through convolutional and pooling layers.

As you can see in Figure 1, CNNs apply **filters** (kernels) across the input to create **feature maps**, progressively reducing the spatial dimensions while increasing feature complexity. This architecture is inherently designed to understand where features are located in space.
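
To make the filter-and-pool idea concrete, here is a minimal NumPy sketch (not from the original post) of a single convolution followed by 2x2 max pooling. The helper names `conv2d_valid` and `max_pool2d` are hypothetical, and the 3x3 edge kernel is hand-picked for illustration; a real CNN learns many such kernels during training. As in most deep learning libraries, the "convolution" here is technically a cross-correlation.

```python
import numpy as np

def conv2d_valid(image, kernel):
    """Slide one 2D kernel over a 2D image (no padding, stride 1)."""
    kh, kw = kernel.shape
    out_h = image.shape[0] - kh + 1
    out_w = image.shape[1] - kw + 1
    out = np.zeros((out_h, out_w))
    for r in range(out_h):
        for c in range(out_w):
            out[r, c] = np.sum(image[r:r + kh, c:c + kw] * kernel)
    return out

def max_pool2d(fmap, size=2):
    """Non-overlapping max pooling; leftover rows/cols that don't fill a window are dropped."""
    h, w = fmap.shape[0] // size, fmap.shape[1] // size
    trimmed = fmap[:h * size, :w * size]
    return trimmed.reshape(h, size, w, size).max(axis=(1, 3))

# Toy 8x8 image: dark left half, bright right half (a vertical edge in the middle)
image = np.zeros((8, 8))
image[:, 4:] = 1.0

# Hand-picked, Sobel-like vertical-edge detector (a trained CNN would learn its kernels)
edge_kernel = np.array([[-1, 0, 1],
                        [-2, 0, 2],
                        [-1, 0, 1]], dtype=float)

feature_map = conv2d_valid(image, edge_kernel)  # spatial size shrinks: 8x8 -> 6x6
pooled = max_pool2d(feature_map)                # downsampled again: 6x6 -> 3x3
print(feature_map.shape, pooled.shape)          # (6, 6) (3, 3)
print(pooled)                                   # strongest response in the middle column
```

The strong responses survive pooling in the column where the edge sits, which is the sense in which a CNN keeps track of *where* a feature occurs while compressing the spatial grid.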

## LSTMs

LSTMs are a special type of Recurrent Neural Network (RNN) specifically designed to **process sequential data**. They were introduced to **mitigate** a common problem in standard RNNs, primarily the **vanishing gradient problem**, which makes it difficult for RNNs to learn long-term dependencies. (Exploding gradients can still occur and are typically handled separately, e.g. with gradient clipping.)

The core of an LSTM's power lies in its Cell State (Memory) and its "gates":

* **Cell State (Memory)**: Imagine this as a "conveyor belt" for information, running through the entire sequence. It carries information across time steps, allowing the network to retain relevant data for long periods.
* **Gates**: Three specialized "gates" control the flow of information into and out of the cell state:
    * **Forget Gate**: Decides what information to discard from the previous cell state.
    * **Input Gate**: Determines how much of the new candidate information (derived from the current input and previous hidden state) should be added to the cell state.
    * **Output Gate**: Controls how much of the current cell state will contribute to the hidden state, which then serves as the output for the current time step and input for the next.

![](lstm-cell.png)

: Figure 2: The internal structure of a Long Short-Term Memory (LSTM) cell.

Figure 2 visually represents these gates and how they interact to selectively update and output information, making LSTMs adept at **remembering crucial details over long sequences**.
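
The gate descriptions above can be sketched as a single LSTM time step in NumPy. This is an illustrative sketch under stated assumptions rather than any library's implementation: `lstm_step` is a hypothetical helper, the weights are random, peephole connections are omitted, and the order in which the four gate blocks are stacked inside `W`, `U` and `b` is an arbitrary choice (each library uses its own fixed ordering).

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, c_prev, W, U, b):
    """One time step of a standard (non-peephole) LSTM cell.

    x_t: (input_dim,); h_prev, c_prev: (hidden_dim,)
    W: (4*hidden_dim, input_dim), U: (4*hidden_dim, hidden_dim), b: (4*hidden_dim,)
    Gate blocks are stacked in the (arbitrary) order: forget, input, candidate, output.
    """
    z = W @ x_t + U @ h_prev + b            # all four pre-activations at once
    f_raw, i_raw, g_raw, o_raw = np.split(z, 4)
    f = sigmoid(f_raw)                      # forget gate: what to drop from c_prev
    i = sigmoid(i_raw)                      # input gate: how much new info to write
    g = np.tanh(g_raw)                      # candidate values from x_t and h_prev
    o = sigmoid(o_raw)                      # output gate: how much of the cell to expose
    c_t = f * c_prev + i * g                # update the "conveyor belt" (cell state)
    h_t = o * np.tanh(c_t)                  # hidden state: output now, input next step
    return h_t, c_t

rng = np.random.default_rng(0)
input_dim, hidden_dim, timesteps = 5, 4, 6
W = rng.normal(scale=0.1, size=(4 * hidden_dim, input_dim))
U = rng.normal(scale=0.1, size=(4 * hidden_dim, hidden_dim))
b = np.zeros(4 * hidden_dim)

h = np.zeros(hidden_dim)
c = np.zeros(hidden_dim)
for x_t in rng.normal(size=(timesteps, input_dim)):  # run the cell over a short sequence
    h, c = lstm_step(x_t, h, c, W, U, b)
print(h.shape, c.shape)  # (4,) (4,)
```

Note that every input here is a 1D vector; this is the constraint discussed in the next section.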

# The Limitations of Individual Networks

While **CNNs and LSTMs** are incredibly powerful in their respective domains, they have significant **limitations** when dealing with data that exhibits both strong spatial and temporal dependencies.

## CNNs: Good at Space but Bad at Time

Standard CNNs are **excellent at extracting features** from individual images or frames, but they **do not inherently capture temporal dependencies between consecutive frames.** When you feed a sequence of images (like a video) into a standard CNN, it treats each frame as an independent input. This means it can recognize objects within each frame, but it loses crucial information about motion, changes, or the sequence of events over time. For example, a CNN could identify a car in two consecutive frames, but it wouldn't inherently understand that the car is moving or how it's moving.
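
A small NumPy experiment illustrates the point. Here a hypothetical `frame_features` function stands in for a per-frame CNN (it only ever sees one frame) and is applied to a toy clip of a bar moving right, and to the same clip reversed (a bar moving left). Both clips produce exactly the same per-frame features, just in a different order, so anything that summarizes them without modeling that order cannot tell the two motions apart.

```python
import numpy as np

def frame_features(frame):
    """Hypothetical stand-in for a per-frame CNN: it only looks at one frame.
    Here the 'feature' is simply the mean brightness of each column."""
    return frame.mean(axis=0)

# Toy clip: a bright 2-pixel-wide bar moving one column to the right per frame
T, H, W = 4, 5, 8
video = np.zeros((T, H, W))
for t in range(T):
    video[t, :, t:t + 2] = 1.0

forward = np.stack([frame_features(f) for f in video])         # bar moving right
backward = np.stack([frame_features(f) for f in video[::-1]])  # same clip reversed: moving left

# Same feature rows, only reordered...
print(np.array_equal(forward, backward[::-1]))                 # True
# ...so an order-agnostic summary (e.g. averaging over time) cannot see the direction
print(np.allclose(forward.mean(axis=0), backward.mean(axis=0)))  # True
```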

## LSTMs: Good at Time but Bad at Space

Traditional LSTMs are **excellent for sequences**, but they **expect a 1D vector** as input at each time step. This presents a major challenge for spatial data like images or video frames. To feed an image (e.g., a 64x64 pixel grayscale image) into a standard LSTM, you would first have to **"flatten" (unroll)** its 2D grid into a long 1D vector (e.g., 4096 pixels).

This "flattening" process leads to two significant problems:

1. **Loss of Spatial Relationships**: When you flatten an image, you destroy the crucial spatial relationships between pixels. Pixels that were close together in the 2D grid become distant in the 1D vector. The LSTM then loses the ability to recognize patterns based on proximity, adjacency, or overall shape – the very strengths of CNNs.
2. **High Dimensionality**: For higher-resolution or color images, flattening can lead to extremely long input vectors, making the LSTM computationally expensive, prone to overfitting, and harder to train effectively.
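
Both problems can be quantified with a little arithmetic. The sketch below assumes row-major flattening (NumPy's default) and a standard LSTM layer with 4 gates, each having input weights, recurrent weights and a bias; the hidden size of 256 is an arbitrary assumption, and `flat_index` / `lstm_param_count` are hypothetical helpers.

```python
import numpy as np

def flat_index(row, col, width):
    """Position of pixel (row, col) after row-major flattening."""
    return row * width + col

def lstm_param_count(input_dim, hidden_dim):
    """Weights + biases of a standard LSTM layer (4 gates, each with
    input weights, recurrent weights and a bias vector)."""
    return 4 * (hidden_dim * (input_dim + hidden_dim) + hidden_dim)

H, W = 64, 64
image = np.arange(H * W).reshape(H, W)
flat = image.ravel()                                  # shape (4096,)
assert flat[flat_index(10, 20, W)] == image[10, 20]   # same pixel, new position

# 1) Spatial relationships: neighbors in 2D are not neighbors in 1D
print(flat_index(10, 21, W) - flat_index(10, 20, W))  # right neighbor: 1
print(flat_index(11, 20, W) - flat_index(10, 20, W))  # pixel below: 64

# 2) Dimensionality: parameters grow with the number of input pixels
print(f"{lstm_param_count(64 * 64, 256):,}")          # 64x64 grayscale: 4,457,472
print(f"{lstm_param_count(256 * 256 * 3, 256):,}")    # 256x256 RGB: 201,589,760
```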

# The Solution: ConvLSTM

Recognizing the limitations of relying solely on CNNs for temporal tasks or LSTMs for spatial tasks, ConvLSTM emerged as a powerful hybrid architecture. It integrates:

* **Convolutional Neural Networks (CNNs)**: To extract spatial features from grid-like data (like images).
* **Long Short-Term Memory networks (LSTMs)**: To learn and remember long-term dependencies in sequential data.

**ConvLSTM** was specifically developed to **process data that has both spatial and temporal dimensions simultaneously**. This makes it well suited to sequential data in which each data point is a multi-dimensional grid, such as:

* Video frames
* Weather radar data
* Medical imaging sequences (e.g., fMRI, dynamic MRI)

Crucially, ConvLSTM **maintains the spatial information** throughout its recurrent connections, unlike traditional LSTMs that would flatten the input, thus providing a much more **effective way to model spatio-temporal dynamics**.

# How ConvLSTM Works

The core innovation of ConvLSTM lies in how it adapts the internal operations of an LSTM cell to handle spatial data directly.

## The 5D Input

Before diving into the mechanics, let's understand the typical input data shape for a ConvLSTM layer. It's usually a 5D tensor with the following dimensions:

`(batch_size, timesteps, height, width, channels)`

Let's break down each dimension:

* `batch_size`: The number of independent sequences processed simultaneously (for parallel computation).
* `timesteps`: The length of the sequence (e.g., the number of video frames in a clip).
* `height`: The spatial height of each input frame/grid.
* `width`: The spatial width of each input frame/grid.
* `channels`: The number of feature channels for each input frame/grid (e.g., 3 for an RGB image, 1 for grayscale).

This 5D structure is crucial because it allows the ConvLSTM to operate on the full spatial dimensions of your data at each time step.
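
As a quick illustration, the sketch below stacks individual frames into this layout with NumPy. The clip of 10 random 60x160 RGB frames is a made-up placeholder. This `(batch, time, height, width, channels)` order is the default `channels_last` layout of Keras' `ConvLSTM2D`; PyTorch has no built-in ConvLSTM layer, and third-party implementations often expect channels before height and width, which is a single transpose away.

```python
import numpy as np

rng = np.random.default_rng(0)

# Placeholder clip: 10 RGB frames, each shaped (height, width, channels)
frames = [rng.random((60, 160, 3), dtype=np.float32) for _ in range(10)]

clip = np.stack(frames, axis=0)         # (timesteps, height, width, channels)
batch = np.stack([clip, clip], axis=0)  # (batch_size, timesteps, height, width, channels)
print(batch.shape)                      # (2, 10, 60, 160, 3)

# Same data in a channels-first layout: (batch, timesteps, channels, height, width)
channels_first = batch.transpose(0, 1, 4, 2, 3)
print(channels_first.shape)             # (2, 10, 3, 60, 160)
```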

## Convolutions Inside the LSTM loop

In a traditional LSTM, the gates (Forget, Input, Output) and the candidate cell state generation perform matrix multiplications with their inputs (current input and previous hidden state).

In ConvLSTM, all these **matrix multiplications are replaced by convolutional operations**. This means that the forget gate, the input gate, the output gate, and the candidate cell state all use convolutional filters (2D or 3D, depending on your data and implementation). This is how convolutions operate inside the LSTM loop.

![](convlstm.png)

: Figure 3: A conceptual diagram of a ConvLSTM cell, highlighting how all internal matrix multiplications are replaced by convolutional operations (denoted by the orange asterisk)

As shown in Figure 3, the inputs $X_t$, $H_{t-1}$, and $C_{t-1}$ are no longer flattened vectors. Instead, they are spatial feature maps (or the raw image/frame), and the weights ($W_{XC}$, $W_{XH}$, etc.) are now convolutional kernels. These kernels slide across the spatial dimensions of the input, preserving its spatial structure.
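
For reference, the ConvLSTM formulation from the original paper by Shi et al. (2015) can be written as follows, where $*$ denotes convolution and $\circ$ the element-wise (Hadamard) product. The subscript naming follows the paper and may not match the labels in Figure 3 exactly. The $W_{c\cdot} \circ C$ "peephole" terms are part of the paper's version; library layers such as Keras' `ConvLSTM2D` do not include them.

$$
\begin{aligned}
i_t &= \sigma\left(W_{xi} * X_t + W_{hi} * H_{t-1} + W_{ci} \circ C_{t-1} + b_i\right) \\
f_t &= \sigma\left(W_{xf} * X_t + W_{hf} * H_{t-1} + W_{cf} \circ C_{t-1} + b_f\right) \\
C_t &= f_t \circ C_{t-1} + i_t \circ \tanh\left(W_{xc} * X_t + W_{hc} * H_{t-1} + b_c\right) \\
o_t &= \sigma\left(W_{xo} * X_t + W_{ho} * H_{t-1} + W_{co} \circ C_t + b_o\right) \\
H_t &= o_t \circ \tanh\left(C_t\right)
\end{aligned}
$$

Replacing each $*$ with a matrix product (and treating $X_t$, $H_t$, $C_t$ as vectors) recovers the ordinary LSTM.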

## Preserving Spatial Hierarchy and Learning Temporal Dependencies

By replacing matrix multiplications with convolutions, ConvLSTM achieves two critical advantages:

1. **Spatial Feature Learning**: Just like a CNN, it can learn and extract spatial features (edges, textures, object parts) from each input frame or spatial slice at every time step.  
2. **Temporal Dependency Modeling**: Like an LSTM, it can maintain and update an internal cell state across time, allowing it to learn long-term dependencies and patterns in the sequence.

This means that instead of just learning what to remember or forget (as in a regular LSTM), a ConvLSTM learns what spatial patterns to remember or forget over time, and how those patterns evolve.
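
Putting the pieces together, here is a minimal NumPy sketch of one ConvLSTM time step, written for readability rather than speed. It is a plain LSTM step with every matrix product replaced by a "same"-padded convolution and no peephole terms; `conv2d_same` and `convlstm_step` are hypothetical helpers, the weights are random, and the gate ordering inside the stacked kernels is an arbitrary choice.

```python
import numpy as np

def conv2d_same(x, kernel):
    """'Same'-padded, stride-1 multi-channel 2D convolution (cross-correlation).
    x: (H, W, C_in), kernel: (k, k, C_in, C_out) with odd k -> (H, W, C_out)."""
    k = kernel.shape[0]
    pad = k // 2
    xp = np.pad(x, ((pad, pad), (pad, pad), (0, 0)))
    H, W, _ = x.shape
    out = np.zeros((H, W, kernel.shape[3]))
    for dr in range(k):
        for dc in range(k):
            # Contribution of kernel tap (dr, dc) at every spatial position at once
            out += xp[dr:dr + H, dc:dc + W, :] @ kernel[dr, dc]
    return out

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def convlstm_step(x_t, h_prev, c_prev, Wx, Wh, b):
    """One ConvLSTM time step (no peepholes).
    x_t: (H, W, C_in); h_prev, c_prev: (H, W, F)
    Wx: (k, k, C_in, 4F); Wh: (k, k, F, 4F); b: (4F,)"""
    z = conv2d_same(x_t, Wx) + conv2d_same(h_prev, Wh) + b  # (H, W, 4F)
    i_raw, f_raw, g_raw, o_raw = np.split(z, 4, axis=-1)
    i, f, o = sigmoid(i_raw), sigmoid(f_raw), sigmoid(o_raw)
    c_t = f * c_prev + i * np.tanh(g_raw)  # element-wise: each pixel keeps its own memory
    h_t = o * np.tanh(c_t)
    return h_t, c_t

rng = np.random.default_rng(0)
H, W, C_in, F, k, T = 6, 8, 3, 4, 3, 5
Wx = rng.normal(scale=0.1, size=(k, k, C_in, 4 * F))
Wh = rng.normal(scale=0.1, size=(k, k, F, 4 * F))
b = np.zeros(4 * F)

h = np.zeros((H, W, F))
c = np.zeros((H, W, F))
for x_t in rng.normal(size=(T, H, W, C_in)):  # a short sequence of 6x8 RGB-like frames
    h, c = convlstm_step(x_t, h, c, Wx, Wh, b)
print(h.shape)  # (6, 8, 4): the hidden state is a stack of feature maps, not a vector
```

The hidden and cell states keep their `(H, W, F)` spatial shape at every step, while the 3x3 kernels let each location's gates depend on its neighbors, which is exactly what the flattened LSTM could not do.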

# A Simple ConvLSTM Implementation: Next-Frame Prediction

Before we dive into the ConvLSTM implementation, let's take a look at the dataset. This section uses a short video clip, randomly sourced online. While its exact subject matter is unknown (my best guess involves some form of gaseous movement!), it provides an excellent and visually engaging sequence for our ConvLSTM model.

{{< video dataset.mp4 >}}

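```python
import os

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

src_path = './dataset'  # folder of extracted frames named img0.png, img1.png, ...

# Assumes the folder contains only the frames, numbered consecutively from 0
length = len(os.listdir(src_path))

images = []
for i in range(length):
    img = Image.open(f'{src_path}/img{i}.png')
    img = img.crop((0, 120, 640, 360))   # (left, upper, right, lower): keep a 640x240 band
    img = img.convert('RGB')
    img = img.resize((160, 60), Image.Resampling.LANCZOS)  # (width, height): 4x downsample
    img = np.asarray(img, dtype=np.float32) / 255.0        # shape (60, 160, 3), values in [0, 1]
    images.append(img)
```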

```python
def train_test(images, step_size, start, stop):
    """Build (input, target) pairs with a sliding window.

    For each index i in [start, stop), the input is frames i .. i+step_size-1
    and the target is the following step_size frames.
    Returns two arrays of shape (stop - start, step_size, 60, 160, 3).
    """
    x, y = [], []
    for i in range(start, stop):
        x.append(np.stack(images[i:i + step_size]))
        y.append(np.stack(images[i + step_size:i + 2 * step_size]))
    return np.stack(x), np.stack(y)

step_size = 3

xtrain, ytrain = train_test(images, step_size, 0, 800)
xval, yval = train_test(images, step_size, 800, 900)
xtest, ytest = train_test(images, step_size, 900, 990)
print(f'train: x{xtrain.shape}, y{ytrain.shape}  test: x{xtest.shape}, y{ytest.shape}')
```
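
To see exactly what `train_test` produces, the same sliding-window logic can be run on frame indices instead of images. `make_windows` below is a hypothetical re-implementation for illustration only, and the 1,000-frame count in the second part is an assumption (the splits above need at least 995 frames). The second part also shows a side effect of these split boundaries: the last training targets reach frame 804, which overlaps the first validation inputs starting at frame 800. Starting each split at least `2 * step_size - 1` frames after the previous split's `stop` would avoid that overlap.

```python
import numpy as np

def make_windows(frames, step_size, start, stop):
    """Hypothetical re-implementation of the windowing in train_test:
    input = frames[i : i+step_size], target = the next step_size frames."""
    x = [frames[i:i + step_size] for i in range(start, stop)]
    y = [frames[i + step_size:i + 2 * step_size] for i in range(start, stop)]
    return np.array(x), np.array(y)

# Each "frame" is just its index, so the windows are easy to read
x, y = make_windows(np.arange(20), step_size=3, start=0, stop=4)
print(x[0], y[0])  # [0 1 2] [3 4 5]
print(x[1], y[1])  # [1 2 3] [4 5 6]  (consecutive samples overlap by step_size - 1 frames)

# Frames shared between training targets and validation inputs (assumes 1,000 frames)
frame_ids = np.arange(1000)
_, y_train_ids = make_windows(frame_ids, 3, 0, 800)
x_val_ids, _ = make_windows(frame_ids, 3, 800, 900)
print(np.intersect1d(y_train_ids, x_val_ids))  # [800 801 802 803 804]
```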

```python
from tensorflow.keras import Sequential
from tensorflow.keras.layers import BatchNormalization, Conv3D, ConvLSTM2D, Input
from tensorflow.keras.optimizers import Adam

def build_model():
    seq = Sequential()
    seq.add(Input(shape=(step_size, 60, 160, 3)))  # (timesteps, height, width, channels)

    # Four stacked ConvLSTM2D blocks; return_sequences=True keeps one output map per time step
    for _ in range(4):
        seq.add(ConvLSTM2D(filters=40, kernel_size=(3, 3),
                           padding='same', return_sequences=True, activation='relu'))
        seq.add(BatchNormalization())

    # Map the 40 feature channels back to 3 RGB channels in [0, 1] for every output frame
    seq.add(Conv3D(filters=3, kernel_size=(3, 3, 3),
                   activation='sigmoid',
                   padding='same', data_format='channels_last'))
    seq.compile(loss='mae', optimizer=Adam(learning_rate=0.001))
    return seq

m = build_model()
m.summary()

m.fit(xtrain, ytrain, epochs=15, batch_size=5,
      validation_data=(xval, yval))
```
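
As a sanity check on the model size, the parameter count of each `ConvLSTM2D` layer can be worked out by hand, assuming the Keras layout: an input kernel, a recurrent kernel and a bias, each covering the 4 gates (`use_bias=True` is the default). These values should agree with the per-layer counts printed by `m.summary()`. `convlstm2d_param_count` and `lstm_param_count` are hypothetical helpers, and the flattened-LSTM comparison (40 units on a flattened 60x160x3 frame) is only a rough reference point, since its hidden state is a 40-element vector rather than 40 full feature maps.

```python
def convlstm2d_param_count(in_channels, filters, kernel=(3, 3)):
    """Trainable parameters of a ConvLSTM2D layer with a bias (4 gates)."""
    kh, kw = kernel
    input_kernel = kh * kw * in_channels * 4 * filters   # convolves X_t
    recurrent_kernel = kh * kw * filters * 4 * filters   # convolves H_{t-1}
    bias = 4 * filters
    return input_kernel + recurrent_kernel + bias

def lstm_param_count(input_dim, units):
    """Same count for a standard LSTM fed flattened frames."""
    return 4 * (units * (input_dim + units) + units)

first = convlstm2d_param_count(in_channels=3, filters=40)    # first ConvLSTM2D layer
deeper = convlstm2d_param_count(in_channels=40, filters=40)  # each of the three deeper layers
flat = lstm_param_count(input_dim=60 * 160 * 3, units=40)    # LSTM on flattened 60x160 RGB frames

print(first, deeper)  # 62080 115360
print(flat)           # 4614560
```

The ConvLSTM counts depend only on kernel size and channel counts, not on the 60x160 frame size, because the same kernels are shared across every spatial position.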

```python
def predictions(model, xtest, ytest):
    """Plot ground-truth and predicted frames for one random test sample."""
    choice = np.random.choice(len(xtest))
    print(f'test sample: {choice}')
    pre = model.predict(xtest[choice:choice + 1])  # keep the batch axis: (1, step_size, 60, 160, 3)

    fig, axes = plt.subplots(2, step_size, figsize=(20, 4))

    for time, ax in enumerate(axes[0]):
        ax.imshow(ytest[choice][time])
        ax.set_title(f"Ground Truth {time+1}")
        ax.axis("off")

    for time, ax in enumerate(axes[1]):
        ax.imshow(np.clip(pre[0][time], 0, 1))  # sigmoid output is already in [0, 1]
        ax.set_title(f"Prediction {time+1}")
        ax.axis("off")

    plt.show()

predictions(m, xtest, ytest)
```
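
A single random sample is hard to judge by eye. One common sanity check for next-frame prediction, not part of the original experiment, is to compare the model's error against a "persistence" baseline that simply repeats the last observed frame. The helpers `persistence_baseline` and `mae` are hypothetical, and the self-check below runs on synthetic data, not on the video used above.

```python
import numpy as np

def persistence_baseline(x, horizon):
    """Naive forecast: repeat the last observed frame for every future step.
    x: (samples, timesteps, H, W, C) -> (samples, horizon, H, W, C)."""
    last = x[:, -1:]                        # (samples, 1, H, W, C)
    return np.repeat(last, horizon, axis=1)

def mae(a, b):
    """Mean absolute error, matching the 'mae' loss used for training."""
    return float(np.mean(np.abs(a - b)))

# Self-check on synthetic data: a "video" that stops changing after the inputs
rng = np.random.default_rng(0)
x_demo = rng.random((2, 3, 4, 4, 1))
y_demo = np.repeat(x_demo[:, -1:], 3, axis=1)
print(mae(persistence_baseline(x_demo, horizon=3), y_demo))  # 0.0

# With the objects from the cells above (not run here):
# model_mae = mae(m.predict(xtest), ytest)
# baseline_mae = mae(persistence_baseline(xtest, horizon=step_size), ytest)
```

If the ConvLSTM's test MAE is not clearly below the baseline's, the model may be mostly copying the last frame rather than learning the motion.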

# Real-World Applications of ConvLSTM

The ability of ConvLSTM to concurrently model spatial and temporal dynamics makes it suitable for a wide array of real-world applications:

* **Video Prediction / Next Frame Prediction**: This is a direct application, as shown in our example. ConvLSTM can learn the motion and evolution of objects and patterns within a video, generating realistic future frames. This has implications for:
    * **Traffic Flow Prediction**: Anticipating vehicle movement on roads.  
    * **Sports Analysis**: Predicting ball trajectories or player movements.  
* **Weather Forecasting and Climate Modeling**: Predicting spatio-temporal patterns of meteorological phenomena like rainfall, temperature maps, or storm movements over time. Radar data, which is essentially a sequence of spatial maps, is a perfect fit.  
* **Action Recognition in Videos**: Identifying human actions or activities (e.g., walking, running, clapping) by analyzing the sequence of spatial features extracted from video frames.  
* **Medical Image Analysis (Time Series)**: Analyzing sequences of medical scans (e.g., fMRI for brain activity, dynamic MRI for organ movement, ultrasound videos) to detect changes, track disease progression, or diagnose conditions. Examples include:
    * **Tumor Growth Monitoring**: Tracking changes in tumor size and shape over time.
    * **Cardiac Motion Analysis**: Assessing heart wall movement from cine MRI sequences.  