# TSA Chapter 8: LSTM Cell Architecture

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/QuantLet/TSA/blob/main/TSA_ch8/TSA_ch8_lstm_architecture/TSA_ch8_lstm_architecture.ipynb)

Detailed LSTM cell architecture with forget, input, candidate, and output gates.

In [None]:
!pip install numpy matplotlib -q

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import FancyBboxPatch, Circle, Rectangle
import warnings
warnings.filterwarnings('ignore')

In [None]:
# Shared palette: every figure element below picks its colour from these names.
COLORS = {
    'blue': '#1A3A6E',
    'red': '#DC3545',
    'green': '#2E7D32',
    'orange': '#E67E22',
    'gray': '#666666',
    'purple': '#8E44AD',
}
BLUE, RED, GREEN, ORANGE, GRAY, PURPLE = (
    COLORS[key] for key in ('blue', 'red', 'green', 'orange', 'gray', 'purple')
)
ACCENT_BLUE = '#4A90D9'

# Global matplotlib defaults: transparent backgrounds, no top/right spines,
# no grid, compact fonts and thin lines, and a frameless transparent legend.
plt.rcParams.update(dict([
    ('figure.facecolor', 'none'),
    ('axes.facecolor', 'none'),
    ('savefig.facecolor', 'none'),
    ('savefig.transparent', True),
    ('axes.spines.top', False),
    ('axes.spines.right', False),
    ('axes.grid', False),
    ('font.size', 10),
    ('axes.titlesize', 12),
    ('axes.labelsize', 10),
    ('xtick.labelsize', 9),
    ('ytick.labelsize', 9),
    ('legend.fontsize', 9),
    ('figure.dpi', 150),
    ('lines.linewidth', 1.2),
    ('axes.linewidth', 0.6),
    ('legend.facecolor', 'none'),
    ('legend.framealpha', 0),
    ('legend.edgecolor', 'none'),
]))
def save_chart(fig, name):
    """Save *fig* as ``<name>.pdf`` and ``<name>.png``, then print a confirmation.

    Both files are written with a tight bounding box, a transparent
    background and 150 dpi. *name* is the output path without extension.
    """
    export_options = dict(bbox_inches='tight', transparent=True, dpi=150)
    # PDF first, then PNG.
    for extension in ('pdf', 'png'):
        fig.savefig(f'{name}.{extension}', **export_options)
    print(f'Saved: {name}')

In [None]:
# Canvas: wide, axis-free, with equal x/y scaling so circular nodes keep their shape.
fig, ax = plt.subplots(figsize=(15, 7.5), dpi=200)
ax.set_xlim(-1.5, 14)
ax.set_ylim(-1, 8)
ax.axis('off')
ax.set_aspect('equal')

LIGHT_BG = '#F8FAFC'

# Main cell body: rounded rectangle anchored at (1.5, 1), 10 wide and 5.5 tall.
cell_rect = FancyBboxPatch(
    (1.5, 1), 10, 5.5,
    boxstyle='round,pad=0.15,rounding_size=0.5',
    facecolor=LIGHT_BG,
    edgecolor=BLUE,
    linewidth=3.5,
)
ax.add_patch(cell_rect)

# The two horizontal state rails, drawn top to bottom:
#   * cell state at y=6 as one continuous line across the whole cell;
#   * hidden state at y=2 as two stubs from x=-1 to 1.5 and x=11.5 to 13.5.
# Each rail gets an entry label on the left, an exit label on the right and
# an arrowhead at its right end.
state_label_kw = dict(va='center', fontsize=16, fontweight='bold')
state_rows = (
    (6, BLUE, '$C_{t-1}$', '$C_t$', ((-1, 13.5),),
     dict(lw=5, solid_capstyle='round', zorder=2)),
    (2, PURPLE, '$h_{t-1}$', '$h_t$', ((-1, 1.5), (11.5, 13.5)),
     dict(lw=4)),
)
for row_y, row_color, label_in, label_out, segments, line_kw in state_rows:
    for x_from, x_to in segments:
        ax.plot([x_from, x_to], [row_y, row_y], color=row_color, **line_kw)
    ax.text(-1.3, row_y, label_in, ha='right', color=row_color, **state_label_kw)
    ax.text(13.8, row_y, label_out, ha='left', color=row_color, **state_label_kw)
    ax.annotate('', xy=(13.5, row_y), xytext=(12.8, row_y),
                arrowprops=dict(arrowstyle='-|>', color=row_color, lw=4, mutation_scale=25))

def draw_gate(x, y, label, color, w=1.4, h=1.0):
    """Draw a gate on the module-level ``ax``.

    The gate is a filled rounded box of size *w* x *h* centred on (*x*, *y*)
    with a white outline, and *label* is written in white at its centre.
    """
    lower_left = (x - w/2, y - h/2)
    box = FancyBboxPatch(
        lower_left, w, h,
        boxstyle='round,pad=0.08',
        facecolor=color,
        edgecolor='white',
        linewidth=3,
        alpha=0.95,
        zorder=5,
    )
    ax.add_patch(box)
    # The label is given a higher zorder than the box so it is drawn on top of it.
    ax.text(x, y, label, ha='center', va='center',
            fontsize=14, fontweight='bold', color='white', zorder=6)

def draw_op(x, y, symbol, color, radius=0.4):
    """Draw an operation node on the module-level ``ax``.

    The node is a white disc of the given *radius* centred on (*x*, *y*),
    outlined in *color*, with *symbol* written in the same colour at its centre.
    """
    node = Circle((x, y), radius, facecolor='white', edgecolor=color,
                  linewidth=3, zorder=5)
    ax.add_patch(node)
    # The symbol is given a higher zorder than the disc so it is drawn on top of it.
    ax.text(x, y, symbol, ha='center', va='center',
            fontsize=15, fontweight='bold', color=color, zorder=6)

def draw_arrow(start, end, color, lw=2.5):
    """Draw an arrow on the module-level ``ax`` from *start* to *end*.

    *start* and *end* are (x, y) pairs; the arrowhead sits at *end*.
    """
    head_style = dict(arrowstyle='-|>', color=color, lw=lw, mutation_scale=18)
    # An empty-text annotation is used only for its arrow:
    # ``xytext`` is the tail, ``xy`` is the head.
    ax.annotate('', xytext=start, xy=end, arrowprops=head_style)

gate_y = 3.5

# Gates: (x position, symbol inside the box, colour, caption below the box).
gate_specs = (
    (3, '$\\sigma$', RED, 'Forget\nGate'),
    (5.5, '$\\sigma$', GREEN, 'Input\nGate'),
    (7.5, 'tanh', ORANGE, 'Candidate'),
    (10, '$\\sigma$', ACCENT_BLUE, 'Output\nGate'),
)
for gate_x, gate_symbol, gate_color, caption in gate_specs:
    draw_gate(gate_x, gate_y, gate_symbol, gate_color)
    # Only the two-line captions are given the tightened line spacing.
    spacing_kw = {'linespacing': 0.9} if '\n' in caption else {}
    ax.text(gate_x, 2.2, caption, ha='center', va='center', fontsize=11,
            color=gate_color, fontweight='bold', **spacing_kw)

# Operation nodes: (x, y, symbol, colour, optional overrides for draw_op).
op_specs = (
    (3, 6, '\u00d7', RED, {}),
    (6.5, 6, '+', GREEN, {}),
    (6.5, 4.8, '\u00d7', ORANGE, {'radius': 0.35}),
    (10, 5, 'tanh', BLUE, {'radius': 0.45}),
    (10, 2, '\u00d7', ACCENT_BLUE, {}),
)
for op_x, op_y, op_symbol, op_color, op_overrides in op_specs:
    draw_op(op_x, op_y, op_symbol, op_color, **op_overrides)

# Connections
# All connector lines use zorder=3, which is below the gate boxes and
# operation nodes (zorder=5), so any part that runs behind a box or node
# is hidden by it.

# Forget gate -> multiply node on the cell-state line.
ax.plot([3, 3], [4.0, 5.6], color=RED, lw=2.5, zorder=3)
draw_arrow((3, 5.0), (3, 5.55), RED, lw=2.5)

# Input gate -> left side of the inner multiply node at (6.5, 4.8).
ax.plot([5.5, 5.5], [4.0, 4.8], color=GREEN, lw=2.5, zorder=3)
ax.plot([5.5, 6.1], [4.8, 4.8], color=GREEN, lw=2.5, zorder=3)

# Candidate -> right side of the same multiply node.
ax.plot([7.5, 7.5], [4.0, 4.8], color=ORANGE, lw=2.5, zorder=3)
ax.plot([7.5, 6.9], [4.8, 4.8], color=ORANGE, lw=2.5, zorder=3)

# Inner multiply node -> add node on the cell-state line.
ax.plot([6.5, 6.5], [5.15, 5.6], color=GREEN, lw=2.5, zorder=3)
draw_arrow((6.5, 5.3), (6.5, 5.55), GREEN, lw=2.5)

# Cell-state line -> tanh node at (10, 5).
ax.plot([10, 10], [6, 5.45], color=BLUE, lw=2.5, zorder=3)
draw_arrow((10, 5.7), (10, 5.48), BLUE, lw=2.5)

# tanh node -> final multiply node at (10, 2).
ax.plot([10, 10], [4.55, 2.4], color=BLUE, lw=2.5, zorder=3)
draw_arrow((10, 3.0), (10, 2.42), BLUE, lw=2.5)

# Output gate -> final multiply node: leave the gate on its left side,
# drop to y=2, then run right to the node's left edge (x = 10 - 0.4).
ax.plot([9.3, 9.3], [3.5, 2], color=ACCENT_BLUE, lw=2.5, zorder=3)
ax.plot([9.3, 9.6], [3.5, 3.5], color=ACCENT_BLUE, lw=2.5, zorder=3)
# Fix: this segment used to be the zero-length line [9.6, 9.6], [2, 2],
# which left the connector dangling at (9.3, 2) without reaching the node.
ax.plot([9.3, 9.6], [2, 2], color=ACCENT_BLUE, lw=2.5, zorder=3)

# Final multiply node -> right edge of the cell body on the hidden-state line.
ax.plot([10.4, 11.5], [2, 2], color=PURPLE, lw=4, zorder=3)

# Input x_t: arrow entering the bottom edge of the cell, with its label underneath.
ax.annotate('', xytext=(6, -0.5), xy=(6, 1),
            arrowprops=dict(arrowstyle='-|>', color=GREEN, lw=3.5, mutation_scale=22))
ax.text(6, -0.8, '$x_t$', ha='center', va='center', fontsize=16, fontweight='bold', color=GREEN)

# Input distribution lines: a dashed horizontal bus at input_y, one dashed
# riser per gate up to y=3.0 (the bottom edge of the gate boxes), and a
# short link from the bus down to where the x_t arrow enters.
input_y = 1.2
dashed_style = dict(color=GRAY, lw=2, ls='--', alpha=0.6, zorder=1)
ax.plot([3, 10], [input_y, input_y], **dashed_style)
for riser_x in (3, 5.5, 7.5, 10):
    ax.plot([riser_x, riser_x], [input_y, 3.0], **dashed_style)
ax.plot([6, 6], [input_y, 1], **dashed_style)

# Dotted branch from the left edge of the cell at y=2 (where the incoming
# hidden-state line ends) over and up to the forget gate.
dotted_style = dict(color=GRAY, lw=2, ls=':', alpha=0.5, zorder=1)
for branch_xs, branch_ys in (([1.5, 3], [2, 2]), ([3, 3], [2, 3.0])):
    ax.plot(branch_xs, branch_ys, **dotted_style)
# Labels: each state gets a boxed title with an unboxed subtitle beneath it.
# Spec: (x, title y, title text, subtitle y, subtitle text, colour).
label_specs = (
    (0.5, 7.2, 'Cell State $C_t$', 6.6, '(Long-term Memory)', BLUE),
    (12.5, 0.8, 'Hidden State $h_t$', 0.3, '(Short-term Memory)', PURPLE),
)
for label_x, title_y, title, subtitle_y, subtitle, tone in label_specs:
    title_box = dict(boxstyle='round,pad=0.3', facecolor='white', edgecolor=tone, alpha=0.9)
    ax.text(label_x, title_y, title, ha='center', va='center', fontsize=13,
            color=tone, fontweight='bold', style='italic', bbox=title_box)
    ax.text(label_x, subtitle_y, subtitle, ha='center', va='center', fontsize=10,
            color=tone, style='italic')

# Legend box: white rounded panel positioned relative to
# (legend_x, legend_y_start).
legend_x = -0.5
legend_y_start = 4.5
legend_box = FancyBboxPatch(
    (legend_x - 0.6, legend_y_start - 2.4), 2.4, 2.7,
    boxstyle='round,pad=0.1',
    facecolor='white',
    edgecolor=GRAY,
    linewidth=1,
    alpha=0.95,
)
ax.add_patch(legend_box)

ax.text(legend_x + 0.5, legend_y_start + 0.1, 'Gates:',
        ha='center', va='center', fontsize=11, fontweight='bold')

# One row per gate: a colour swatch followed by the gate name in the same colour.
items = [(RED, 'Forget'), (GREEN, 'Input'), (ORANGE, 'Candidate'), (ACCENT_BLUE, 'Output')]
for row, (swatch_color, gate_name) in enumerate(items):
    # Rows are stacked downwards at a pitch of 0.55 starting 0.4 below the top.
    y_pos = legend_y_start - 0.4 - row * 0.55
    ax.add_patch(Rectangle((legend_x - 0.3, y_pos - 0.15), 0.5, 0.3,
                           facecolor=swatch_color, edgecolor='white', lw=1.5))
    ax.text(legend_x + 0.4, y_pos, gate_name, ha='left', va='center',
            fontsize=10, color=swatch_color, fontweight='bold')

# Tighten the layout, export the figure as PDF and PNG, then display it.
plt.tight_layout()
save_chart(fig, 'ch8_lstm_architecture')
plt.show()