# Simulated Annealing(焼きなまし)
## 扱う題材
結晶構造の緩和を題材とする。
以下のリンクの課題を参考に実装する。  
[インターン課題](https://github.com/pfnet/intern-coding-tasks/blob/master/2022/je06/JE06_task.ipynb)

In [28]:
from typing import Tuple, Optional
from itertools import product
import copy
import time
from dataclasses import dataclass
import numpy as np
import ase
import ase.optimize
from ase.calculators.calculator import Calculator, all_changes
from ase.visualize import view

In [32]:
class LennardJonesI22:
    """Lennard-Jones pair potential for eight atom types (type indices 0-7).

    Type ``k`` has ``sigma_k = 2.0 + 0.2 * k`` and ``epsilon_k = 0.1 + 0.1 * k``.
    Cross-type parameters combine sigma by arithmetic mean and epsilon by
    geometric mean (Lorentz-Berthelot mixing).
    """

    def __init__(self):
        k = np.arange(8, dtype=float)
        sigma_per_type = 2.0 + 0.2 * k
        epsilon_per_type = 0.1 + 0.1 * k

        # 8x8 symmetric tables; entry [a, b] is the parameter for an (a, b) pair.
        self.sigma_matrix = 0.5 * np.add.outer(sigma_per_type, sigma_per_type)
        self.epsilon_matrix = np.sqrt(np.multiply.outer(epsilon_per_type, epsilon_per_type))

    def calculate(self, atom_type: np.ndarray, positions: np.ndarray) -> Tuple[float, np.ndarray]:
        """Return the total energy and its gradient for one configuration.

        Args:
            atom_type: 1-D array of integer type indices, one per atom.
            positions: ``(n_atoms, 3)`` Cartesian coordinates.

        Returns:
            ``(energy, grad)``. ``grad`` has the shape of ``positions`` and is
            dE/dx (negate it to obtain forces). ``energy`` is ``inf`` when two
            atoms occupy the same position.
        """
        assert positions.ndim == 2
        assert positions.shape[1] == 3
        n = positions.shape[0]

        assert atom_type.ndim == 1
        assert atom_type.shape[0] == n

        # Broadcast the type vector along both axes: entry [i, j] holds the
        # parameter for the pair (atom j, atom i); the tables are symmetric.
        row_types = atom_type[None, :]
        col_types = atom_type[:, None]
        sigma = self.sigma_matrix[row_types, col_types]
        epsilon = self.epsilon_matrix[row_types, col_types]

        # displacement[i, j] = x_i - x_j
        displacement = positions[:, None, :] - positions[None, :, :]
        dist_sq = np.square(displacement).sum(axis=2)
        # 1/r^2, with 0 (not inf) on the diagonal and for coincident atoms.
        inv_dist_sq = np.reciprocal(dist_sq, out=np.zeros_like(dist_sq), where=(dist_sq != 0.0))

        sr2 = np.square(sigma) * inv_dist_sq
        sr6 = np.power(sr2, 3)
        sr12 = np.square(sr6)
        pair_energy = 4.0 * epsilon * (sr12 - sr6)
        # (dE_ij/dr) / r; times the displacement this gives dE_ij/dx_i.
        dE_dr_over_r = -24.0 * epsilon * (2 * sr12 - sr6) * inv_dist_sq
        pair_grad = dE_dr_over_r[:, :, None] * displacement

        # Every pair appears twice in the full matrix, hence the factor 0.5.
        per_atom_energy = 0.5 * pair_energy.sum(axis=1)
        energy = float(per_atom_energy.sum().item())
        coincident = np.any(dist_sq + np.identity(n) == 0.0)
        if coincident:
            energy = float("inf")
        grad = 0.5 * (pair_grad.sum(axis=1) - pair_grad.sum(axis=0))

        return energy, grad
    


class LennardJonesI22Calculator(Calculator):
    """ASE calculator backed by :class:`LennardJonesI22`.

    The atomic numbers of the ``Atoms`` object are used directly as row/column
    indices into the 8x8 I22 parameter tables, so they must lie in 0-7.
    """
    implemented_properties = ['energy', 'forces', 'free_energy']

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.lj_core = LennardJonesI22()

    def calculate(self, atoms=None, properties=None, system_changes=all_changes):
        """Compute energy, free energy and forces for ``atoms``.

        ``atoms`` may be None, in which case the structure previously attached
        to this calculator is used.
        """
        if properties is None:
            properties = self.implemented_properties

        # The base class stores a copy of ``atoms`` in ``self.atoms`` (or keeps
        # the existing one when ``atoms`` is None). Reading from ``self.atoms``
        # makes both call styles work; using ``atoms`` directly crashed on None.
        super().calculate(atoms, properties, system_changes)
        target = self.atoms
        if target is None:
            raise ValueError("LennardJonesI22Calculator: no Atoms object to evaluate")

        energy, gradient = self.lj_core.calculate(target.get_atomic_numbers(), target.get_positions())
        self.results["energy"] = energy
        self.results["free_energy"] = energy
        # lj_core returns dE/dx; forces are its negative.
        self.results["forces"] = -gradient
        

@dataclass
class Config:
    """Tunable parameters for the simulated-annealing run.

    Every value is a real dataclass field with a default, so it can be
    overridden via ``Config(...)`` and is still readable as a class attribute
    (e.g. ``Config.n_atom``).
    """
    time_limit: float = 100.0  # wall-clock budget for the annealing loop, in seconds
    step: float = 4.0          # spacing of the initial cubic grid of atoms
    init_temp: float = 10      # temperature at the start of the run
    end_temp: float = 1        # temperature reached when time_limit expires
    n_atom: int = 64           # number of atoms in the system


@dataclass
class State:
    """Snapshot of one candidate configuration visited by the search.

    All three fields must be passed explicitly, although each may be None.
    """

    # Cartesian coordinates of every atom, one row per atom.
    positions: Optional[np.ndarray]
    # Type index of every atom, aligned row-for-row with ``positions``.
    atomic_numbers: Optional[np.ndarray]
    # Potential energy of the (relaxed) configuration.
    energy: Optional[float]


# Move the centroid of the atoms to the origin.
def normalize_positions(positions: np.ndarray):
    """Return ``positions`` translated so their unweighted centroid is at the origin.

    This is a plain mean over rows (not mass-weighted). A new array is
    returned and the input is left unchanged.
    """
    return positions - np.mean(positions, axis=0)

# Initial-solution generation
def make_init_sol(base_atomic_array: np.ndarray, base_positions: np.ndarray, calculator: Calculator) -> State:
    """Build a random, locally relaxed starting configuration.

    The atom types are randomly permuted. Each reference position is jittered
    by a uniform offset in [-0.5, 0.5) per coordinate, and the cloud is
    re-centred. The result is then relaxed with BFGS (fmax=0.001).

    Args:
        base_atomic_array: type index of each atom; NOT modified.
        base_positions: reference coordinates, one row per atom.
        calculator: ASE calculator used for the relaxation.

    Returns:
        State with the relaxed positions, the shuffled types and the energy.
    """
    # np.random.shuffle works in place; shuffle a copy so the caller's array
    # is not scrambled as a side effect.
    atomic_numbers = np.array(base_atomic_array, copy=True)
    np.random.shuffle(atomic_numbers)

    # Noise shape follows the actual input instead of a hard-coded (64, 3).
    noises = np.random.rand(*np.shape(base_positions)) - 0.50
    pre_positions = normalize_positions(base_positions + noises)

    atoms = ase.Atoms(numbers=atomic_numbers, positions=pre_positions)
    atoms.calc = calculator

    opt = ase.optimize.BFGS(atoms, logfile=None)
    opt.run(fmax=0.001)

    return State(
        positions=atoms.positions,
        atomic_numbers=atomic_numbers,
        energy=atoms.get_potential_energy(),
    )


def _swap_two(atomic_numbers: np.ndarray, n_swaps: int) -> None:
    """Swap the types of two random, differently-typed atoms, ``n_swaps`` times, in place."""
    n = len(atomic_numbers)
    # Without two distinct types the rejection loop below could never succeed.
    if len(np.unique(atomic_numbers)) < 2:
        raise ValueError("a two-point swap needs at least two distinct atom types")
    for _ in range(n_swaps):
        while True:
            p1, p2 = np.random.choice(n, 2, replace=False)
            if atomic_numbers[p1] != atomic_numbers[p2]:
                atomic_numbers[p1], atomic_numbers[p2] = atomic_numbers[p2], atomic_numbers[p1]
                break


def _rotate_three(atomic_numbers: np.ndarray, left: bool) -> None:
    """Cyclically shift the types of three random, mutually distinct-typed atoms, in place."""
    n = len(atomic_numbers)
    # Without three distinct types the rejection loop below could never succeed.
    if len(np.unique(atomic_numbers)) < 3:
        raise ValueError("a three-point rotation needs at least three distinct atom types")
    while True:
        p1, p2, p3 = np.random.choice(n, 3, replace=False)
        a1, a2, a3 = atomic_numbers[p1], atomic_numbers[p2], atomic_numbers[p3]
        if a1 != a2 and a2 != a3 and a3 != a1:
            break
    if left:
        # (a1, a2, a3) -> (a2, a3, a1)
        atomic_numbers[p1], atomic_numbers[p2], atomic_numbers[p3] = a2, a3, a1
    else:
        # (a1, a2, a3) -> (a3, a1, a2)
        atomic_numbers[p1], atomic_numbers[p2], atomic_numbers[p3] = a3, a1, a2


def _relax(atomic_numbers: np.ndarray, positions: np.ndarray, calculator: Calculator) -> Tuple[np.ndarray, float]:
    """Relax the structure with BFGS (fmax=0.001) and return (positions, energy)."""
    atoms = ase.Atoms(numbers=atomic_numbers, positions=positions)
    atoms.calc = calculator
    ase.optimize.BFGS(atoms, logfile=None).run(fmax=0.001)
    return atoms.positions, atoms.get_potential_energy()


def step(cur_state: State, calculator: Calculator) -> State:
    """Propose a neighbouring state: permute atom types, then re-relax.

    Move probabilities: 40% one to four two-point swaps, 30% a left 3-cycle,
    30% a right 3-cycle. ``cur_state`` is not modified; a new State is returned.

    Raises:
        ValueError: if the chosen move needs more distinct atom types than exist.
    """
    ret = copy.deepcopy(cur_state)
    r = np.random.randint(10)

    if r < 4:
        _swap_two(ret.atomic_numbers, np.random.randint(1, 5))
    elif r < 7:
        _rotate_three(ret.atomic_numbers, left=True)
    else:
        _rotate_three(ret.atomic_numbers, left=False)

    positions = normalize_positions(ret.positions)
    ret.positions, ret.energy = _relax(ret.atomic_numbers, positions, calculator)
    return ret

In [33]:
def run_SA() -> State:
    """Minimise the I22 Lennard-Jones energy by simulated annealing over type arrangements.

    Setup: eight copies of each type 0-7 are placed on a jittered 4x4x4 grid
    (spacing ``Config.step``) and relaxed by ``make_init_sol``.

    Each iteration: ``step`` proposes a neighbour, which is accepted with the
    Metropolis rule ``exp(-dE / T)``, where ``dE`` is measured against the
    current state. On rejection the current state is kept.

    Schedule: T falls linearly from ``Config.init_temp`` to
    ``Config.end_temp`` over ``Config.time_limit`` seconds.

    Returns:
        The lowest-energy state seen during the run.
    """
    start_time = time.perf_counter()
    base_atomic_array = np.array([0, 1, 2, 3, 4, 5, 6, 7] * 8)
    intervals = np.arange(start=0.0, stop=Config.step * 4, step=Config.step)
    base_positions = np.array(list(product(intervals, intervals, intervals)))
    calculator = LennardJonesI22Calculator()

    cur_state = make_init_sol(base_atomic_array, base_positions, calculator)
    best_state = copy.deepcopy(cur_state)

    n_iter = 0
    elapsed = time.perf_counter() - start_time
    while elapsed <= Config.time_limit:
        n_iter += 1
        alpha = elapsed / Config.time_limit
        cur_temp = Config.init_temp - (Config.init_temp - Config.end_temp) * alpha

        candidate = step(cur_state, calculator)
        # Metropolis criterion relative to the CURRENT state (not the best one).
        # A zero/negative temperature degenerates to greedy descent.
        delta = candidate.energy - cur_state.energy
        if delta <= 0 or (cur_temp > 0 and np.random.random() < np.exp(-delta / cur_temp)):
            cur_state = candidate

        if cur_state.energy < best_state.energy:
            print(f"Iteration {n_iter:06d} : Lowest Energy is updated!! {best_state.energy:4.3f} -> {cur_state.energy:4.3f}")
            best_state = copy.deepcopy(cur_state)

        elapsed = time.perf_counter() - start_time

    return best_state
    

if __name__ == "__main__":
    # Run the annealing search. The result is bound at module level so the
    # visualisation code further down can read ``best_state``.
    best_state = run_SA()

Iteration 000001 : Lowest Energy is updated!! -142.828 -> -142.980
Iteration 000029 : Lowest Energy is updated!! -142.980 -> -143.283
Iteration 000030 : Lowest Energy is updated!! -143.283 -> -143.410
Iteration 000037 : Lowest Energy is updated!! -143.410 -> -143.849
Iteration 000039 : Lowest Energy is updated!! -143.849 -> -143.895
Iteration 000040 : Lowest Energy is updated!! -143.895 -> -144.227
Iteration 000060 : Lowest Energy is updated!! -144.227 -> -144.714
Iteration 000066 : Lowest Energy is updated!! -144.714 -> -144.734
Iteration 000070 : Lowest Energy is updated!! -144.734 -> -145.075
Iteration 000076 : Lowest Energy is updated!! -145.075 -> -145.930
Iteration 000106 : Lowest Energy is updated!! -145.930 -> -146.514
Iteration 000115 : Lowest Energy is updated!! -146.514 -> -146.520


In [35]:
# Type indices 0-7 are shifted to atomic numbers 20-27 for display. In ASE,
# atomic number 0 is the dummy element "X"; the offset presumably gives each
# type a distinct real element in the viewer.
atoms = ase.Atoms(
    positions=best_state.positions,
    numbers=best_state.atomic_numbers + 20,
)
view(atoms, viewer='x3d')