# Chapter 9: Advanced Transformations and Supercells

## Learning Objectives
- Build supercells using transformation matrices
- Apply strain to crystals
- Work with non-orthogonal cells
- Transform between conventional and primitive cells

---

## 9.1 Supercell Transformation Matrices

A **supercell** is created by combining multiple unit cells. Any supercell can be described by a transformation matrix **P**:

$$\begin{pmatrix} \mathbf{a'} \\ \mathbf{b'} \\ \mathbf{c'} \end{pmatrix} = \mathbf{P} \cdot \begin{pmatrix} \mathbf{a} \\ \mathbf{b} \\ \mathbf{c} \end{pmatrix}$$

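Here **P** is a 3×3 integer matrix with nonzero determinant; the supercell contains $|\det \mathbf{P}|$ copies of the original cell. Because lattice vectors are stored as rows, the new vectors are computed as `P @ old_vectors`.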

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from typing import List, Tuple, Optional
from dataclasses import dataclass

@dataclass
class LatticeParameters:
    """Unit-cell lengths and angles.

    Angles are given in degrees and default to 90 degrees (orthogonal axes).
    They follow the crystallographic convention: alpha is the angle between
    b and c, beta between a and c, and gamma between a and b.
    """
    a: float  # length of the a vector
    b: float  # length of the b vector
    c: float  # length of the c vector
    alpha: float = 90.0  # angle(b, c), degrees
    beta: float = 90.0   # angle(a, c), degrees
    gamma: float = 90.0  # angle(a, b), degrees

def lattice_vectors_from_parameters(params: LatticeParameters) -> np.ndarray:
    alpha = np.radians(params.alpha)
    beta = np.radians(params.beta)
    gamma = np.radians(params.gamma)
    
    a_vec = np.array([params.a, 0, 0])
    bx = params.b * np.cos(gamma)
    by = params.b * np.sin(gamma)
    b_vec = np.array([bx, by, 0])
    
    cx = params.c * np.cos(beta)
    cy = params.c * (np.cos(alpha) - np.cos(beta)*np.cos(gamma)) / np.sin(gamma)
    cz = np.sqrt(max(0, params.c**2 - cx**2 - cy**2))
    c_vec = np.array([cx, cy, cz])
    
    return np.array([a_vec, b_vec, c_vec])

class Crystal:
    """Crystal structure."""
    ELEMENTS = {
        'Cu': {'color': 'brown', 'radius': 1.28},
        'Fe': {'color': 'orange', 'radius': 1.26},
        'Si': {'color': 'gold', 'radius': 1.17},
        'O': {'color': 'red', 'radius': 0.73},
        'Mg': {'color': 'green', 'radius': 1.60},
    }
    
    def __init__(self, name: str, lattice_vectors: np.ndarray):
        self.name = name
        self.lattice_vectors = np.array(lattice_vectors, dtype=float)
        self.basis: List[Tuple[str, np.ndarray]] = []
    
    @classmethod
    def from_parameters(cls, name: str, params: LatticeParameters) -> 'Crystal':
        vectors = lattice_vectors_from_parameters(params)
        return cls(name, vectors)
    
    def add_atom(self, symbol: str, fractional: np.ndarray) -> None:
        self.basis.append((symbol, np.array(fractional, dtype=float)))
    
    def add_atoms(self, symbol: str, positions: List[np.ndarray]) -> None:
        for pos in positions:
            self.add_atom(symbol, pos)
    
    @property
    def n_atoms(self) -> int:
        return len(self.basis)
    
    @property
    def volume(self) -> float:
        return abs(np.linalg.det(self.lattice_vectors))
    
    def fractional_to_cartesian(self, frac: np.ndarray) -> np.ndarray:
        return frac @ self.lattice_vectors
    
    def cartesian_to_fractional(self, cart: np.ndarray) -> np.ndarray:
        return cart @ np.linalg.inv(self.lattice_vectors)
    
    def get_cartesian_positions(self) -> Tuple[List[str], np.ndarray]:
        symbols = [s for s, _ in self.basis]
        frac = np.array([p for _, p in self.basis])
        cart = self.fractional_to_cartesian(frac)
        return symbols, cart
    
    def copy(self) -> 'Crystal':
        new = Crystal(self.name, self.lattice_vectors.copy())
        for sym, pos in self.basis:
            new.add_atom(sym, pos.copy())
        return new

# Build FCC copper
def build_fcc(element: str, a: float) -> Crystal:
    """Return the 4-atom conventional FCC cell of ``element``.

    Args:
        element: Chemical symbol placed on every site.
        a: Cubic lattice constant.
    """
    cubic = LatticeParameters(a, a, a, 90, 90, 90)
    crystal = Crystal.from_parameters(f"{element}_FCC", cubic)
    # One corner site plus the three face-centre sites of the cube.
    corner_and_faces = [[0, 0, 0], [0.5, 0.5, 0], [0.5, 0, 0.5], [0, 0.5, 0.5]]
    crystal.add_atoms(element, corner_and_faces)
    return crystal

cu = build_fcc(element='Cu', a=3.615)
fcc_summary = f"FCC Cu: {cu.n_atoms} atoms, V = {cu.volume:.2f} Å³"
print(fcc_summary)

## 9.2 Creating Supercells with Transformation Matrices

In [None]:
def make_supercell(crystal: Crystal, P: np.ndarray) -> Crystal:
    """Create a supercell using an integer transformation matrix ``P``.

    Lattice vectors are stored as rows, so the new vectors are
    ``new_vectors = P @ old_vectors``. A point with fractional row vector
    ``f_old`` in the original cell therefore has fractional coordinates
    ``f_new = f_old @ inv(P)`` in the supercell.

    Args:
        crystal: Original crystal.
        P: 3x3 integer transformation matrix (integer-valued floats are
            accepted).

    Returns:
        New Crystal containing ``|det P|`` copies of the original basis.

    Raises:
        ValueError: If ``P`` is not 3x3, has non-integer entries, or is
            singular.
    """
    P_arr = np.asarray(P, dtype=float)
    if P_arr.shape != (3, 3):
        raise ValueError(f"Transformation matrix must be 3x3, got shape {P_arr.shape}")
    P_int = np.rint(P_arr).astype(int)
    if not np.allclose(P_arr, P_int):
        # Truncating e.g. 0.5 -> 0 would silently build a different cell.
        raise ValueError("Transformation matrix must contain only integers")

    det_P = int(round(abs(np.linalg.det(P_int))))
    if det_P == 0:
        raise ValueError("Transformation matrix is singular")

    new_vectors = P_int @ crystal.lattice_vectors
    supercell = Crystal(f"{crystal.name}_supercell", new_vectors)

    P_inv = np.linalg.inv(P_int.astype(float))
    eps = 1e-8

    # The supercell spans f_new in [0, 1)^3, i.e. original lattice points
    # n = f_new @ P. Each component of n is bounded by the sum of the
    # negative (resp. positive) entries in the matching column of P.
    lo = np.minimum(P_int, 0).sum(axis=0)
    hi = np.maximum(P_int, 0).sum(axis=0)

    for i in range(lo[0], hi[0] + 1):
        for j in range(lo[1], hi[1] + 1):
            for k in range(lo[2], hi[2] + 1):
                lattice_point = np.array([i, j, k], dtype=float)
                # Row-vector convention: f_new = f_old @ inv(P).
                new_frac = lattice_point @ P_inv
                if not (np.all(new_frac >= -eps) and np.all(new_frac < 1 - eps)):
                    continue
                for symbol, basis_frac in crystal.basis:
                    atom_frac_new = np.mod((lattice_point + basis_frac) @ P_inv, 1.0)
                    # Values within eps of 1 are rounding noise around 0.
                    atom_frac_new[atom_frac_new > 1 - eps] = 0.0
                    supercell.add_atom(symbol, atom_frac_new)

    return supercell

# Example: 2x2x2 supercell (diagonal P matrix)
P_222 = np.diag([2, 2, 2])

cu_222 = make_supercell(cu, P_222)
volume_ratio = cu_222.volume / cu.volume
print(f"2×2×2 supercell: {cu_222.n_atoms} atoms (expected: 4×8=32)")
print(f"Volume ratio: {volume_ratio:.1f} (expected: 8)")

In [None]:
# Non-diagonal supercell: rotated cell
# Rows of P give a' = a + b and b' = -a + b: a √2 × √2 × 1 supercell
# rotated 45° in the xy-plane.
P_rot45 = np.array([[1, 1, 0], [-1, 1, 0], [0, 0, 1]])

cu_rot = make_supercell(cu, P_rot45)
print(f"√2×√2×1 rotated supercell: {cu_rot.n_atoms} atoms")
for label, vectors in (("Old vectors:", cu.lattice_vectors),
                       ("New vectors:", cu_rot.lattice_vectors)):
    print(label)
    print(vectors)

## 9.3 Conventional to Primitive Cell Transformations

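The conventional FCC cell contains 4 lattice points (4 atoms for an elemental metal such as Cu), while the primitive cell contains exactly 1, so it has one quarter of the volume.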

In [None]:
# Transformation matrices for centering types.
# Each matrix P maps conventional to primitive vectors (rows):
#     primitive_vectors = P @ conventional_vectors
# Entries are written as integers and halved; every value is exact in
# floating point.
CENTERING_TRANSFORMS = {
    'P': np.eye(3),                                             # Primitive (no change)
    'F': np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0]]) / 2,       # FCC
    'I': np.array([[-1, 1, 1], [1, -1, 1], [1, 1, -1]]) / 2,    # BCC
    'C': np.array([[1, 1, 0], [-1, 1, 0], [0, 0, 2]]) / 2,      # C-centered
}

def conventional_to_primitive(crystal: Crystal, centering: str,
                              tolerance: float = 1e-6) -> Crystal:
    """Transform a centred conventional cell to its primitive cell.

    Every basis atom is mapped into the primitive cell with
    ``f_prim = f_conv @ inv(P)`` and wrapped to [0, 1). Atoms related by a
    centring translation land on the same primitive site and are kept only
    once, so multi-atom bases (e.g. rock salt) are handled correctly.

    Args:
        crystal: Conventional cell.
        centering: Key of CENTERING_TRANSFORMS ('P', 'F', 'I' or 'C').
        tolerance: Fractional-coordinate tolerance used to detect duplicate
            sites.

    Returns:
        Primitive Crystal.

    Raises:
        KeyError: If ``centering`` is not a known centring type.
    """
    P = CENTERING_TRANSFORMS[centering]
    P_inv = np.linalg.inv(P)

    # New lattice vectors (rows): primitive = P @ conventional
    new_vectors = P @ crystal.lattice_vectors
    primitive = Crystal(f"{crystal.name}_primitive", new_vectors)

    for symbol, frac in crystal.basis:
        prim_frac = np.mod(frac @ P_inv, 1.0)
        # Values within tolerance of 1 are rounding noise around 0.
        prim_frac[prim_frac > 1 - tolerance] = 0.0

        # Compare with the minimum-image (periodic) difference so that
        # e.g. 0.0 and 0.9999999 are recognised as the same site.
        is_duplicate = any(
            kept_symbol == symbol
            and np.all(np.abs((prim_frac - kept_frac)
                              - np.round(prim_frac - kept_frac)) < tolerance)
            for kept_symbol, kept_frac in primitive.basis
        )
        if not is_duplicate:
            primitive.add_atom(symbol, prim_frac)

    return primitive

# Convert FCC conventional to primitive
cu_primitive = conventional_to_primitive(cu, 'F')
for label, cell, noun in (("conventional", cu, "atoms"),
                          ("primitive", cu_primitive, "atom")):
    print(f"FCC {label}: {cell.n_atoms} {noun}, V = {cell.volume:.2f} Å³")
print(f"Volume ratio: {cu.volume / cu_primitive.volume:.1f}")

In [None]:
def primitive_to_conventional(crystal: Crystal, centering: str) -> Crystal:
    """Transform a primitive cell back to its centred conventional cell.

    Because ``primitive_vectors = P @ conventional_vectors``, a primitive
    fractional position maps to ``f_conv = f_prim @ P``. Every basis atom is
    then replicated by each centring translation of the conventional cell.

    Args:
        crystal: Primitive cell.
        centering: Key of CENTERING_TRANSFORMS ('P', 'F', 'I' or 'C').

    Returns:
        Conventional Crystal.

    Raises:
        KeyError: If ``centering`` is not a known centring type.
    """
    P = CENTERING_TRANSFORMS[centering]
    P_inv = np.linalg.inv(P)

    # Inverse transformation: conventional = inv(P) @ primitive
    new_vectors = P_inv @ crystal.lattice_vectors
    conventional = Crystal(f"{crystal.name}_conventional", new_vectors)

    # Lattice translations of each centring type, in conventional fractions.
    centering_translations = {
        'F': [[0, 0, 0], [0.5, 0.5, 0], [0.5, 0, 0.5], [0, 0.5, 0.5]],
        'I': [[0, 0, 0], [0.5, 0.5, 0.5]],
        'C': [[0, 0, 0], [0.5, 0.5, 0]],
    }
    translations = np.array(centering_translations.get(centering, [[0, 0, 0]]),
                            dtype=float)

    for symbol, prim_frac in crystal.basis:
        conv_frac = prim_frac @ P
        for shift in translations:
            position = np.mod(conv_frac + shift, 1.0)
            # Values within 1e-8 of 1 are rounding noise around 0.
            position[position > 1 - 1e-8] = 0.0
            conventional.add_atom(symbol, position)

    return conventional

# Verify round-trip
cu_conv = primitive_to_conventional(cu_primitive, 'F')
n_back, v_back = cu_conv.n_atoms, cu_conv.volume
print(f"Back to conventional: {n_back} atoms, V = {v_back:.2f} Å³")

## 9.4 Applying Strain to Crystals

**Strain** is a deformation described by a strain tensor:

$$\varepsilon = \begin{pmatrix} \varepsilon_{xx} & \varepsilon_{xy} & \varepsilon_{xz} \\ \varepsilon_{xy} & \varepsilon_{yy} & \varepsilon_{yz} \\ \varepsilon_{xz} & \varepsilon_{yz} & \varepsilon_{zz} \end{pmatrix}$$

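For a small, rotation-free deformation, the deformation gradient is $\mathbf{F} = \mathbf{I} + \varepsilon$. Each lattice vector transforms as $\mathbf{v}' = \mathbf{F}\,\mathbf{v}$, while fractional atomic coordinates stay unchanged.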

In [None]:
def apply_strain(crystal: Crystal, strain: np.ndarray) -> Crystal:
    """Apply strain tensor to crystal.
    
    Args:
        crystal: Original crystal
        strain: 3x3 strain tensor (symmetric)
    
    Returns:
        Strained crystal
    """
    strain = np.array(strain, dtype=float)
    
    # Deformation gradient
    F = np.eye(3) + strain
    
    # New lattice vectors
    new_vectors = crystal.lattice_vectors @ F.T
    
    strained = Crystal(f"{crystal.name}_strained", new_vectors)
    
    # Copy atoms (fractional coords unchanged)
    for sym, pos in crystal.basis:
        strained.add_atom(sym, pos.copy())
    
    return strained

# Example: 2% tensile strain along x
strain_x = np.zeros((3, 3))
strain_x[0, 0] = 0.02

cu_strained = apply_strain(cu, strain_x)
a_before, a_after = (np.linalg.norm(cell.lattice_vectors[0])
                     for cell in (cu, cu_strained))
print(f"Original a: {a_before:.4f} Å")
print(f"Strained a: {a_after:.4f} Å")
print(f"Expected: {3.615 * 1.02:.4f} Å")

In [None]:
# Hydrostatic strain (uniform expansion/compression)
def hydrostatic_strain(crystal: Crystal, delta: float) -> Crystal:
    """Apply hydrostatic (uniform) strain.
    
    Args:
        delta: Fractional volume change (positive = expansion)
    """
    # For volume change by (1+delta), linear strain is (1+delta)^(1/3) - 1
    linear_strain = (1 + delta)**(1/3) - 1
    strain = np.eye(3) * linear_strain
    return apply_strain(crystal, strain)

# Biaxial strain (strain in xy-plane, relax z)
def biaxial_strain(crystal: Crystal, delta: float, poisson: float = 0.3) -> Crystal:
    """Apply equibiaxial in-plane strain, e.g. a thin film clamped to a substrate.

    The out-of-plane component follows the isotropic Poisson relation
    eps_zz = -2 * nu / (1 - nu) * delta.

    Args:
        delta: In-plane strain applied along both x and y.
        poisson: Poisson's ratio nu.
    """
    out_of_plane = -2 * poisson / (1 - poisson) * delta
    return apply_strain(crystal, np.diag([delta, delta, out_of_plane]))

# Example: 1% compressive biaxial strain
cu_biaxial = biaxial_strain(cu, delta=-0.01, poisson=0.34)
a_len, c_len = (np.linalg.norm(cu_biaxial.lattice_vectors[idx]) for idx in (0, 2))
print(f"Original cell: a = {np.linalg.norm(cu.lattice_vectors[0]):.4f} Å")
print("Biaxial strained:")
print(f"  a = {a_len:.4f} Å")
print(f"  c = {c_len:.4f} Å")

In [None]:
# Shear strain
def shear_strain(crystal: Crystal, gamma: float, plane: str = 'xy') -> Crystal:
    """Apply pure shear strain.
    
    Args:
        gamma: Shear angle (engineering strain)
        plane: 'xy', 'xz', or 'yz'
    """
    strain = np.zeros((3, 3))
    
    if plane == 'xy':
        strain[0, 1] = gamma / 2
        strain[1, 0] = gamma / 2
    elif plane == 'xz':
        strain[0, 2] = gamma / 2
        strain[2, 0] = gamma / 2
    elif plane == 'yz':
        strain[1, 2] = gamma / 2
        strain[2, 1] = gamma / 2
    
    return apply_strain(crystal, strain)

# 5% shear in xy-plane
cu_shear = shear_strain(cu, gamma=0.05, plane='xy')
for header, vectors in (("Original vectors:", cu.lattice_vectors),
                        ("\nSheared vectors:", cu_shear.lattice_vectors)):
    print(header)
    print(vectors)

## 9.5 Cell Reduction (Niggli Reduction)

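Different choices of lattice vectors can describe the same lattice. **Niggli reduction** finds a standard, "shortest" cell. The implementation below is a simplified version for demonstration. It orders and shortens the basis vectors but does not enforce the full set of Niggli sign conditions.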

In [None]:
def niggli_reduce(lattice_vectors: np.ndarray, tolerance: float = 1e-8,
                  max_iterations: int = 100) -> np.ndarray:
    """Reduce lattice vectors to a short, nearly orthogonal basis.

    This is a simplified Niggli-style reduction for demonstration. Each
    sweep does two things:

    1. It orders the vectors by length (|a| <= |b| <= |c|).
    2. It subtracts from each vector the integer multiple of another vector
       that shortens it most (a Gauss/Lagrange reduction step).

    Every step is a unimodular change of basis, so the lattice and the cell
    volume are preserved. The handedness (sign of the determinant) is
    preserved as well. The sign conditions on the off-diagonal metric
    elements that define a true Niggli cell are not enforced.

    Args:
        lattice_vectors: 3x3 array with the lattice vectors as rows.
        tolerance: Threshold on squared lengths below which a change is
            treated as numerical noise.
        max_iterations: Upper bound on the number of reduction sweeps.

    Returns:
        New 3x3 float array of reduced lattice vectors (rows). The input
        array is not modified.

    Raises:
        ValueError: If the input is not 3x3 or its vectors are linearly
            dependent.
    """
    vectors = np.array(lattice_vectors, dtype=float)  # independent copy
    if vectors.shape != (3, 3):
        raise ValueError(f"lattice_vectors must be 3x3, got shape {vectors.shape}")
    handedness = np.sign(np.linalg.det(vectors))
    if handedness == 0:
        raise ValueError("Lattice vectors are linearly dependent")

    for _ in range(max_iterations):
        # Step 1: order by length, A <= B <= C.
        lengths_sq = np.einsum('ij,ij->i', vectors, vectors)
        if np.any(np.diff(lengths_sq) < -tolerance):
            vectors = vectors[np.argsort(lengths_sq, kind='stable')]
            if np.sign(np.linalg.det(vectors)) != handedness:
                # An odd permutation flips handedness; negating one vector
                # restores it without changing any length.
                vectors[2] = -vectors[2]

        # Step 2: shorten each vector using the other two. The nearest
        # integer multiple strictly reduces the length whenever any
        # multiple does, so skewed cells converge in a few sweeps.
        changed = False
        for i in range(3):
            for j in range(3):
                if i == j:
                    continue
                mu = np.round(np.dot(vectors[i], vectors[j])
                              / np.dot(vectors[j], vectors[j]))
                if mu == 0:
                    continue
                candidate = vectors[i] - mu * vectors[j]
                if np.dot(candidate, candidate) < np.dot(vectors[i], vectors[i]) - tolerance:
                    vectors[i] = candidate
                    changed = True

        if not changed:
            break

    return vectors

# Test with a non-standard cell.
# P_long is unimodular (integer entries, det = +1), so long_cell spans exactly
# the same lattice as the conventional Cu cell, only with a skewed first
# vector. An integer matrix with |det| > 1 would instead build a supercell.
P_long = np.array([
    [1, 2, 0],
    [0, 1, 0],
    [0, 0, 1]
])

long_cell = P_long @ cu.lattice_vectors
print("Non-reduced cell vectors:")
print(long_cell)
print(f"Lengths: {[np.linalg.norm(v) for v in long_cell]}")

reduced = niggli_reduce(long_cell)
print("\nReduced cell vectors:")
print(reduced)
print(f"Lengths: {[np.linalg.norm(v) for v in reduced]}")
# Equal volumes confirm that both bases describe the same lattice.
print(f"Volume: {abs(np.linalg.det(long_cell)):.2f} Å³ -> "
      f"{abs(np.linalg.det(reduced)):.2f} Å³")

## 9.6 Affine Transformations for Crystals

In [None]:
def rotate_crystal(crystal: Crystal, axis: np.ndarray, angle: float) -> Crystal:
    """Rotate crystal around an axis through the origin.
    
    Args:
        axis: Rotation axis
        angle: Rotation angle in degrees
    """
    axis = np.array(axis, dtype=float)
    axis = axis / np.linalg.norm(axis)
    theta = np.radians(angle)
    
    # Rodrigues' formula
    K = np.array([
        [0, -axis[2], axis[1]],
        [axis[2], 0, -axis[0]],
        [-axis[1], axis[0], 0]
    ])
    R = np.eye(3) + np.sin(theta) * K + (1 - np.cos(theta)) * K @ K
    
    # Rotate lattice vectors
    new_vectors = crystal.lattice_vectors @ R.T
    
    rotated = Crystal(f"{crystal.name}_rotated", new_vectors)
    for sym, pos in crystal.basis:
        rotated.add_atom(sym, pos.copy())
    
    return rotated

# Rotate Cu by 45° around z-axis
cu_rotated = rotate_crystal(cu, axis=[0, 0, 1], angle=45)
for header, vectors in (("Original vectors:", cu.lattice_vectors),
                        ("\nRotated 45° around z:", cu_rotated.lattice_vectors)):
    print(header)
    print(vectors)

## 9.7 Visualizing Transformations

In [None]:
def plot_crystal(crystal: Crystal, ax: Optional[plt.Axes] = None,
                 show_cell: bool = True, title: Optional[str] = None) -> plt.Axes:
    """Plot a crystal's atoms and, optionally, its unit-cell edges in 3D.

    Args:
        crystal: Structure to draw.
        ax: Existing 3D axes to draw into. When None, a new 8x8 figure
            with a 3D subplot is created.
        show_cell: Whether to draw the 12 edges of the unit cell.
        title: Axes title. Defaults to ``crystal.name``.

    Returns:
        The axes that were drawn into.
    """
    if ax is None:
        fig = plt.figure(figsize=(8, 8))
        ax = fig.add_subplot(111, projection='3d')

    symbols, positions = crystal.get_cartesian_positions()
    symbol_array = np.array(symbols)

    # dict.fromkeys keeps first-appearance order, so the legend order is
    # reproducible; set() order depends on string hashing.
    for elem in dict.fromkeys(symbols):
        elem_pos = positions[symbol_array == elem]
        props = Crystal.ELEMENTS.get(elem, {'color': 'gray', 'radius': 1.0})
        ax.scatter(elem_pos[:, 0], elem_pos[:, 1], elem_pos[:, 2],
                   s=props['radius'] * 200, c=props['color'],
                   edgecolors='black', label=elem)

    # Draw cell: 8 corners of the parallelepiped spanned by a, b, c.
    if show_cell:
        a, b, c = crystal.lattice_vectors
        verts = [np.zeros(3), a, a + b, b, c, a + c, a + b + c, b + c]
        edges = [(0, 1), (1, 2), (2, 3), (3, 0), (4, 5), (5, 6), (6, 7), (7, 4),
                 (0, 4), (1, 5), (2, 6), (3, 7)]
        for i, j in edges:
            ax.plot3D(*zip(verts[i], verts[j]), 'k-', alpha=0.5)

    ax.set_xlabel('X (Å)')
    ax.set_ylabel('Y (Å)')
    ax.set_zlabel('Z (Å)')
    ax.set_title(title or crystal.name)
    # Only add a legend when there are labelled atoms; otherwise matplotlib
    # emits a "No artists with labels" warning.
    if symbols:
        ax.legend()

    return ax

# Compare original and strained
fig = plt.figure(figsize=(15, 5))
ax1, ax2, ax3 = (fig.add_subplot(131 + k, projection='3d') for k in range(3))

for panel_ax, structure, panel_title in zip(
        (ax1, ax2, ax3),
        (cu, cu_biaxial, cu_shear),
        ("Original FCC Cu", "Biaxial Strained", "Sheared")):
    plot_crystal(structure, panel_ax, title=panel_title)

plt.tight_layout()
plt.show()

---

## Practice Exercises

### Exercise 9.1: Create (110) Surface Cell

Use a transformation matrix to create a cell suitable for (110) surface calculations.

In [None]:
# YOUR CODE HERE
def create_110_cell(crystal: Crystal) -> Crystal:
    """Create cell with [110] as z-axis.
    
    New axes should be:
    a' = [001]
    b' = [1-10]
    c' = [110]
    """
    # TODO: Find the transformation matrix P
    pass

### Exercise 9.2: Strain Energy

Implement a function that computes the elastic strain energy density of a cubic crystal from its elastic constants $C_{11}$, $C_{12}$ and $C_{44}$.

In [None]:
# YOUR CODE HERE
def strain_energy_cubic(strain: np.ndarray, C11: float, C12: float, C44: float) -> float:
    """Return the elastic strain energy density of a cubic crystal.

    The elastic constants are given in Voigt notation. Note that the Voigt
    shear components are engineering strains (2 * eps_ij), and that
    1 GPa = 1e9 J/m³.

    Args:
        strain: 3x3 strain tensor.
        C11, C12, C44: Elastic constants in GPa.

    Returns:
        Energy density in J/m³.
    """
    # TODO: Implement using elastic energy formula
    # E = (1/2) * C_ijkl * eps_ij * eps_kl
    return None

### Exercise 9.3: Epitaxial Matching

Find a supercell transformation that matches a film to a substrate.

In [None]:
# YOUR CODE HERE
def find_epitaxial_match(substrate_a: float, film_a: float,
                         max_cells: int = 5) -> Tuple[int, int, float]:
    """Find the supercell sizes that best match a film to a substrate.

    Search pairs (n_substrate, n_film) with sizes up to ``max_cells`` and
    pick the one for which n_substrate * substrate_a and n_film * film_a
    differ least.

    Returns:
        (n_substrate, n_film, mismatch), where n_substrate substrate cells
        match n_film film cells with minimum strain mismatch.
    """
    # TODO: Implement
    return None

---

## Key Takeaways

1. **Transformation matrices** describe any supercell or cell change
2. **Primitive/conventional** transformations reduce or expand cells
3. **Strain tensors** describe crystal deformations
4. **Cell reduction** finds standard representations
5. **Unimodular integer transformations** (det = ±1) describe the same lattice; integer matrices with |det| > 1 build supercells

## Next Chapter Preview

In Chapter 10, we'll learn how to **export structures** to common file formats (XYZ, PDB, CIF, POSCAR).