In [None]:
!pip install qiskit qiskit-aer matplotlib pillow

import matplotlib.pyplot as plt
import numpy as np
from IPython.display import HTML
from matplotlib.animation import FuncAnimation
from matplotlib.animation import PillowWriter
from qiskit import QuantumCircuit
from qiskit.quantum_info import Statevector
from qiskit.quantum_info import partial_trace
from qiskit.visualization import plot_bloch_multivector


def idx(i, j, L):
    """Map lattice coordinates (row ``i``, column ``j``) to a row-major qubit index."""
    row_offset = i * L
    return row_offset + j
def neighbors(i, j, L):
    """Return the 4-connected neighbours of site (i, j) on an L x L grid.

    The order is fixed: up, down, left, right. ``apply`` relies on this,
    because it only uses the first entry. Each direction is bounds-checked
    only along its own axis.
    """
    moves = (
        (i > 0, (i - 1, j)),
        (i < L - 1, (i + 1, j)),
        (j > 0, (i, j - 1)),
        (j < L - 1, (i, j + 1)),
    )
    return [cell for allowed, cell in moves if allowed]

def rule():
    """Return the two-qubit update gate applied to each lattice bond.

    Gate sequence on (q0, q1): H on q0, CNOT q0->q1, RX(pi/3) on q1, CZ(q0, q1).
    The circuit is converted to a reusable instruction.
    """
    control, target = 0, 1
    gate = QuantumCircuit(2)
    gate.h(control)
    gate.cx(control, target)
    gate.rx(np.pi / 3, target)
    gate.cz(control, target)
    return gate.to_instruction()

def apply(qc, u, L):
    """Append one layer of the two-qubit gate ``u`` to ``qc`` for an L x L lattice.

    Sites are visited in row-major order. Each site is paired with the FIRST
    entry returned by ``neighbors``, acting as the gate's first qubit.
    Isolated sites (no neighbours, e.g. L == 1) are skipped.

    NOTE(review): ``neighbors`` lists vertical neighbours first, so for L >= 2
    only vertical bonds are ever coupled, and the bond between rows 0 and 1
    receives ``u`` twice, once in each orientation. Confirm this is intended.
    """
    for site in range(L * L):
        i, j = divmod(site, L)
        nbs = neighbors(i, j, L)
        if not nbs:
            continue
        partner = idx(*nbs[0], L)
        qc.append(u, [idx(i, j, L), partner])

def simulate(L=2, steps=4):
    """Evolve an L x L qubit lattice under ``rule()`` and record every state.

    The initial state applies X to every qubit in column ``L // 2`` and H to
    every qubit in row ``L // 2``. Each of the ``steps`` iterations then
    applies one layer of ``rule()`` gates, as built by ``apply``.

    Args:
        L: Lattice side length; the register has ``L * L`` qubits.
        steps: Number of update layers to apply.

    Returns:
        List of ``steps + 1`` Statevectors: the initial state followed by the
        state after each layer.
    """
    n = L * L
    u = rule()

    init = QuantumCircuit(n)
    for k in range(L):
        init.x(idx(k, L // 2, L))
        init.h(idx(L // 2, k, L))

    # The update layer is identical at every step, so build it once.
    layer = QuantumCircuit(n)
    apply(layer, u, L)

    state = Statevector.from_instruction(init)
    evol = [state]
    for _ in range(steps):
        # Evolve the current state by one layer. The old code re-simulated the
        # whole (ever-growing) circuit from |0...0> each step, which cost
        # O(steps**2) layer applications instead of O(steps).
        state = state.evolve(layer)
        evol.append(state)
    return evol

def bloch_image(sv, L):
    """Render the reduced Bloch sphere of every lattice site as an RGB array.

    For each qubit ``i``, all other qubits are traced out. The resulting
    one-qubit state is drawn with ``plot_bloch_multivector`` and rasterised.

    Args:
        sv: Full ``L*L``-qubit state (anything ``partial_trace`` accepts).
        L: Lattice side length.

    Returns:
        List of ``L*L`` uint8 arrays of shape (height, width, 3), in qubit order.
    """
    n = L * L
    imgs = []
    for i in range(n):
        reduced = partial_trace(sv, [q for q in range(n) if q != i])
        # NOTE(review): the reduced state has a single qubit, so the tile's
        # built-in label is presumably "qubit 0" for every site. The site is
        # identified by its position in the returned list. Confirm.
        fig = plot_bloch_multivector(reduced)
        fig.canvas.draw()
        # Take the shape from the RGBA buffer itself. get_width_height()
        # reports logical pixels, which can differ from the buffer size on
        # HiDPI canvases and break a manual reshape. Copy so the result does
        # not alias renderer memory after the figure is closed.
        rgba = np.asarray(fig.canvas.buffer_rgba())
        imgs.append(rgba[:, :, :3].copy())
        plt.close(fig)
    return imgs

def animate_bloch_images(evol, L):
    """Animate each state's per-site Bloch spheres as an L x L mosaic.

    Args:
        evol: Sequence of ``L*L``-qubit states (e.g. the output of ``simulate``).
        L: Lattice side length.

    Returns:
        A ``FuncAnimation`` with one frame per state and an 800 ms interval.

    Raises:
        ValueError: If ``evol`` is empty.
    """
    if len(evol) == 0:
        raise ValueError("evol must contain at least one state")

    def mosaic(t):
        # Tile the site images row by row so the picture mirrors the lattice
        # (site index = row * L + col, matching idx()).
        imgs = bloch_image(evol[t], L)
        return np.vstack([np.hstack(imgs[r * L:(r + 1) * L]) for r in range(L)])

    first = mosaic(0)
    fig, ax = plt.subplots()
    ax.axis('off')
    # Initialise with the real first mosaic. imshow fixes the image extent
    # from its initial data, so a fixed-size placeholder would pin the extent
    # to the placeholder's shape rather than the actual frames.
    image = ax.imshow(first)

    def update(t):
        # Frame 0 is redrawn on every init/save pass; reuse the precomputed
        # mosaic instead of re-rendering L*L Bloch figures.
        image.set_data(first if t == 0 else mosaic(t))
        return [image]

    anim = FuncAnimation(fig, update, frames=len(evol), interval=800)
    plt.close(fig)
    return anim

def _as_heat_grid(frame):
    """Return ``frame`` as a real-valued 2-D array that ``imshow`` can display.

    Objects with ``probabilities`` and ``num_qubits`` (e.g. the Statevectors
    returned by ``simulate``) are reduced to a square grid. Each cell holds the
    probability of measuring |1> on that qubit, laid out row-major like
    ``idx``. Anything else goes through ``np.asarray`` unchanged.
    """
    if not hasattr(frame, "probabilities"):
        return np.asarray(frame)
    n = frame.num_qubits
    side = int(round(np.sqrt(n)))
    if side * side != n:
        raise ValueError(f"cannot arrange {n} qubits on a square lattice")
    probs_one = [frame.probabilities([q])[1] for q in range(n)]
    return np.array(probs_one).reshape(side, side)

def animate(evol):
    """Animate a sequence of 2-D frames as an 'inferno' heat map on [0, 1].

    Args:
        evol: Sequence of 2-D arrays, or of square-lattice Statevectors (as
            returned by ``simulate``). Statevectors are converted to per-site
            P(|1>) grids, because ``imshow`` cannot display them directly.

    Returns:
        A ``FuncAnimation`` with one frame per entry and a 200 ms interval.

    Raises:
        ValueError: If ``evol`` is empty or a state's qubit count is not a
            perfect square.
    """
    frames = [_as_heat_grid(f) for f in evol]
    if not frames:
        raise ValueError("evol must contain at least one frame")
    fig, ax = plt.subplots(figsize=(5, 5))
    image = ax.imshow(frames[0], cmap="inferno", vmin=0, vmax=1)

    def update(t):
        image.set_data(frames[t])
        ax.set_title(f"t = {t}")
        return [image]

    anim = FuncAnimation(fig, update, frames=len(frames), interval=200)
    plt.close(fig)
    return anim

# Notebook driver: simulate a 4 x 4 lattice, animate the per-site Bloch
# spheres, save them as a GIF, then show both animations inline.
from IPython.display import display

L = 4
STEPS = 15
HEAT_STEPS = 20

evol = simulate(L, STEPS)
anim = animate_bloch_images(evol, L)
gif_writer = PillowWriter(fps=2)
anim.save("simulation.gif", writer=gif_writer)
# A notebook cell only auto-renders its LAST expression. Without display(),
# this Bloch-sphere animation would be rendered to HTML and then discarded.
display(HTML(anim.to_jshtml()))

# animate() draws a 2-D heat map, but simulate() returns Statevectors, which
# imshow cannot display. Reduce each state to an L x L grid of per-site
# P(|1>), laid out row-major like idx().
heat_frames = [
    np.array([sv.probabilities([q])[1] for q in range(L * L)]).reshape(L, L)
    for sv in simulate(L, HEAT_STEPS)
]
display(HTML(animate(heat_frames).to_jshtml()))