In [2]:
import sys,os

# go up a dir
# Work from the project root (the parent of the directory the notebook was started in).
# The root is computed only once per kernel: without this guard, every re-run of the
# cell would climb one more directory level.
if "PROJECT_ROOT" not in globals():
    PROJECT_ROOT = os.path.dirname(os.path.abspath("./"))
os.chdir(PROJECT_ROOT)
print(os.getcwd())

/Users/matthewcox/Documents/UCL/MSc/Full_Phen_SOEN


In [3]:
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler,MinMaxScaler
from model.soen_model import SOENModel
from model_config_files.two_moons_config import TwoMoonsConfig


from utils.soen_model_utils import *
from tqdm import tqdm

## Equilibrium Propagation Experiments

In [None]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import seaborn as sns
import os
import csv
import numpy as np
from model.soen_model import SOENModel
from model.model_config import SOENConfig
import copy

# Plot font sizes used by plot_combined_mse:
# LABELSIZE for legend entries and tick labels, FONTSIZE for the axis labels.
LABELSIZE = 16
FONTSIZE = 18

class SOENModelEP(SOENModel):
    """SOEN network trained with a two-phase, Equilibrium-Propagation-style procedure.

    ``forward`` relaxes the network twice: a *free* phase (inputs clamped, outputs
    unconstrained) and then a *nudged* phase in which the output units are also pulled
    towards ``y`` with strength ``beta``. ``compute_weight_updates`` turns the
    difference between the two phases into a weight change according to
    ``self.learning_rule``, and ``update_weights`` applies it.

    Attributes set here:
        is_weakly_clamped: True while the nudged phase is running.
        beta: Nudging strength read by ``step`` during the nudged phase.
        target: Output target read by ``step`` during the nudged phase.
        state_evolution: ``(s, phi)`` snapshot appended after every step. This class
            never clears it; callers reset it between experiments.
        learning_rule: Key of ``rule_params`` selecting the update rule.
        rule_params: Per-rule ``beta`` and ``learning_rate``.
        run_to_equilibrium, tol: When enabled, each phase stops early once the largest
            absolute state change between consecutive steps is below ``tol``.

    NOTE(review): ``J``, ``mask``, ``gamma``, ``g``, ``tau``, ``dt``, ``flux_offset``,
    ``clip_phi``, ``clip_state`` and ``num_input``/``num_output``/``num_total`` are
    presumably provided by ``SOENModel`` -- confirm there.
    """

    def __init__(self, config: SOENConfig):
        super().__init__(config)
        self.is_weakly_clamped = False
        self.beta = 0.0
        self.target = None
        self.state_evolution = []
        self.learning_rule = "simple_state"
        self.rule_params = {
            "simple_state": {"beta": 1, "learning_rate": 1},
            "simple_flux": {"beta": 1, "learning_rate": 1},
            "EqProp": {"beta": 1, "learning_rate": 1},
            "EqProp_var1": {"beta": 1, "learning_rate": 1},
            "SOEN1": {"beta": 1, "learning_rate": 1},
        }
        self.run_to_equilibrium = config.run_to_equilibrium
        self.tol = config.tol

    def forward(self, x, y=None, free_steps=100, nudged_steps=20, beta=0.1, initial_state=None):
        """Run a free phase followed by a nudged phase.

        Args:
            x: Values written into the first ``num_input`` state units
                (presumably shape ``(batch, num_input)`` -- TODO confirm).
            y: Target for the last ``num_output`` units during the nudged phase.
            free_steps: Maximum number of steps in the free phase.
            nudged_steps: Maximum number of steps in the nudged phase.
            beta: Nudging strength for the nudged phase.
            initial_state: Optional starting state; it is cloned, not modified.
                Defaults to a ``(batch, num_total)`` zero tensor.

        Returns:
            ``(free_state, nudged_state, output, final_state, free_phi, nudged_phi)``.
            ``output`` is the last ``num_output`` columns of the state after the
            *nudged* phase. ``final_state`` is that same nudged state.
        """
        batch_size = x.shape[0]
        if initial_state is None:
            s = torch.zeros(batch_size, self.num_total, device=x.device)
        else:
            s = initial_state.clone()

        s[:, :self.num_input] = x  # Clamp input

        # Free phase: no nudging term in step().
        self.is_weakly_clamped = False
        self.beta = 0.0
        self.target = None

        s, phi = self.run_phase(s, x, free_steps)
        free_state = s.clone()
        free_phi = phi.clone()

        # Nudged phase: continue from the free-phase state with outputs pulled towards y.
        self.is_weakly_clamped = True
        self.beta = beta
        self.target = y

        s, phi = self.run_phase(s, x, nudged_steps)
        nudged_state = s.clone()
        nudged_phi = phi.clone()

        return free_state, nudged_state, s[:, -self.num_output:], s, free_phi, nudged_phi

    def run_phase(self, s, x, max_steps):
        """Iterate ``step`` up to ``max_steps`` times and return the last ``(s, phi)``.

        Every step's ``(s, phi)`` is appended to ``state_evolution``. If
        ``run_to_equilibrium`` is set, the phase stops as soon as the largest absolute
        state change between consecutive steps is below ``tol``.

        Raises:
            ValueError: If ``max_steps < 1``. Without at least one step there is no
                flux to return.
        """
        if max_steps < 1:
            raise ValueError(f"max_steps must be >= 1, got {max_steps}")
        prev_s = s.clone()
        for _ in range(max_steps):
            s, phi = self.step(s, x)
            self.state_evolution.append((s.clone(), phi.clone()))

            if self.run_to_equilibrium:
                diff = torch.max(torch.abs(s - prev_s))
                if diff < self.tol:
                    break
            prev_s = s.clone()

        return s, phi

    def step(self, s, x):
        """Advance the state by one explicit Euler step of size ``self.dt``.

        The update is computed as follows:
        - ``phi = s @ (J * mask).T + flux_offset``, clamped to ``[-0.5, 0.5]`` when
          ``clip_phi`` is set.
        - ``ds = gamma * g(phi, s) - s / tau``.
        - During the nudged phase, ``beta * (target - s_out)`` is added to the output
          units.

        The new state is clamped to ``[0, 1]`` when ``clip_state`` is set. The input
        units are then re-clamped to ``x``.

        Returns:
            ``(new_state, phi)``, where ``phi`` is the flux computed from the incoming state.
        """
        J_masked = self.J * self.mask
        phi = torch.mm(s, J_masked.t()) + self.flux_offset

        if self.clip_phi:
            phi = torch.clamp(phi, -0.5, 0.5)

        ds = self.gamma * self.g(phi, s) - s / self.tau

        if self.is_weakly_clamped:
            output_idx = slice(-self.num_output, None)
            ds[:, output_idx] += self.beta * (self.target - s[:, output_idx])

        if self.clip_state:
            s = torch.clamp(s + self.dt * ds, 0.0, 1)
        else:
            s = s + self.dt * ds

        # Keep input clamped
        s[:, :self.num_input] = x

        return s, phi

    def compute_weight_updates(self, free_state, nudged_state, free_phi, nudged_phi, beta):
        """Return ``dJ`` (same shape as ``J``) for the current ``learning_rule``.

        Every rule has the form ``-(nudged-phase term - free-phase term) / beta`` for
        each connection ``(i, j)`` whose mask entry is non-zero. Masked entries stay
        zero. Only batch element 0 of the supplied states and fluxes is used.
        """
        dJ = torch.zeros_like(self.J)
        rule = self.learning_rule
        f_s, n_s = free_state[0], nudged_state[0]
        f_phi, n_phi = free_phi[0], nudged_phi[0]

        for i in range(self.J.shape[0]):
            for j in range(self.J.shape[1]):
                if self.mask[i, j] == 0:
                    continue  # Only update existing connections
                if rule == "simple_state":
                    dJ[i, j] = -(n_s[i] * n_s[j] - f_s[i] * f_s[j]) / beta
                elif rule == "simple_flux":
                    # The free-phase term must use free_phi; subtracting the nudged term
                    # from itself would make dJ identically zero.
                    dJ[i, j] = -(n_phi[i] * n_phi[j] - f_phi[i] * f_phi[j]) / beta
                elif rule == "EqProp":
                    # Both factors of each term come from the same phase (phi and s alike).
                    dJ[i, j] = -(self.gamma[i] / beta) * (
                        self.g(n_phi[i], n_s[i]) * self.g(n_phi[j], n_s[j])
                        - self.g(f_phi[i], f_s[i]) * self.g(f_phi[j], f_s[j])
                    )
                elif rule == "EqProp_var1":
                    # NOTE(review): pairs the flux of unit j with the state of unit i and
                    # squares g; kept as written -- confirm this is the intended variant.
                    dJ[i, j] = -(self.gamma[i] / beta) * (
                        self.g(n_phi[j], n_s[i]) * self.g(n_phi[j], n_s[i])
                        - self.g(f_phi[j], f_s[i]) * self.g(f_phi[j], f_s[i])
                    )
                elif rule == "SOEN1":
                    dJ[i, j] = -(self.gamma[i] / beta) * (
                        self.g(n_phi[i], n_s[i]) * n_s[j]
                        - self.g(f_phi[i], f_s[i]) * f_s[j]
                    )

        return dJ

    def update_weights(self, dJ):
        """Apply ``J += learning_rate * dJ`` for the current rule, then re-apply the mask."""
        learning_rate = self.rule_params[self.learning_rule]["learning_rate"]
        with torch.no_grad():
            self.J.add_(learning_rate * dJ)
            self.J.mul_(self.mask)  # Ensure masked weights remain zero

    def set_learning_rule(self, rule):
        """Select the learning rule by name.

        Raises:
            ValueError: If ``rule`` is not a key of ``rule_params``.
        """
        valid_rules = list(self.rule_params.keys())
        if rule in valid_rules:
            self.learning_rule = rule
        else:
            raise ValueError(f"Unknown learning rule: {rule}. Valid rules are: {', '.join(valid_rules)}")


def run_experiment(original_model, x, y, free_steps, nudged_steps, num_iterations, rule):
    """Train a fresh copy of ``original_model`` with one learning rule and record its loss.

    The model is deep-copied, so every rule starts from identical weights and
    ``original_model`` is left untouched. Each iteration does three things:
    - runs one free and one nudged phase, warm-started from the previous iteration's
      final state;
    - applies the rule's weight update;
    - records the MSE of the *free-phase* output against ``y``.

    Args:
        original_model: Model exposing the SOENModelEP interface.
        x: Input tensor passed to the model.
        y: Target tensor passed to the model.
        free_steps: Maximum number of steps in the free phase.
        nudged_steps: Maximum number of steps in the nudged phase.
        num_iterations: Number of training iterations.
        rule: Learning-rule name (a key of ``model.rule_params``).

    Returns:
        list[float]: MSE of the free-phase output for each iteration. Each value is
        measured before that iteration's weight update.
    """
    model = copy.deepcopy(original_model)
    model.set_learning_rule(rule)
    # Resolve beta once so the nudged phase and the weight update use the same value,
    # including when the rule's params omit "beta".
    beta = model.rule_params[rule].get("beta", 0.1)
    current_state = None
    model.state_evolution = []
    mse_history = []
    loss_fn = nn.MSELoss()

    # The EP update is computed by hand, so autograd is not needed. Disabling it avoids
    # building (and, via state_evolution, retaining) a graph over every relaxation step.
    with torch.no_grad():
        for _ in range(num_iterations):
            free_state, nudged_state, _nudged_output, current_state, free_phi, nudged_phi = model(
                x, y,
                free_steps=free_steps,
                nudged_steps=nudged_steps,
                beta=beta,
                initial_state=current_state,
            )

            dJ = model.compute_weight_updates(free_state, nudged_state, free_phi, nudged_phi, beta)
            model.update_weights(dJ)

            # Score the network's own prediction (free phase). The nudged output is
            # pulled towards y by construction and would understate the error.
            free_output = free_state[:, -model.num_output:]
            mse_history.append(loss_fn(free_output, y).item())

    return mse_history

def save_results_to_csv(results, filename):
    """Write per-iteration MSE histories to a long-format CSV file.

    Args:
        results: Sequence with one entry per run. Each entry maps a learning-rule name
            to that rule's list of per-iteration MSE values.
        filename: Destination path; an existing file is overwritten.

    The CSV has the header ``Run, Rule, Iteration, MSE``. Run and iteration numbers
    are 1-based.
    """
    with open(filename, 'w', newline='') as handle:
        out = csv.writer(handle)
        out.writerow(['Run', 'Rule', 'Iteration', 'MSE'])
        out.writerows(
            (run_idx, rule_name, step_idx, value)
            for run_idx, per_rule in enumerate(results, start=1)
            for rule_name, history in per_rule.items()
            for step_idx, value in enumerate(history, start=1)
        )

def plot_combined_mse(results, save_path):
    """Plot the mean +/- one std MSE curve of every learning rule across runs.

    Args:
        results: List of per-run dicts mapping rule name -> per-iteration MSE list.
            The rule set is taken from ``results[0]``. Every run must contain those
            rules, with histories of equal length, so they stack into one array.
        save_path: Directory into which ``combined_mse_plot.png`` is written.

    The y-axis is log-scaled. The figure is saved and then shown.
    """
    fig, ax = plt.subplots(figsize=(12, 8), dpi=300)

    for rule_name in results[0].keys():
        histories = np.array([per_run[rule_name] for per_run in results])
        avg = histories.mean(axis=0)
        spread = histories.std(axis=0)

        iters = range(1, len(avg) + 1)
        ax.plot(iters, avg, label=rule_name, linewidth=2)
        ax.fill_between(iters, avg - spread, avg + spread, alpha=0.3)

    ax.set_xlabel('Iteration', fontsize=FONTSIZE)
    ax.set_ylabel('Mean Squared Error', fontsize=FONTSIZE)
    ax.set_yscale('log')
    ax.legend(fontsize=LABELSIZE)
    ax.tick_params(axis='both', which='major', labelsize=LABELSIZE)
    ax.grid(True, linestyle='--', alpha=0.5)
    fig.tight_layout()
    fig.savefig(os.path.join(save_path, "combined_mse_plot.png"))
    plt.show()

if __name__ == "__main__":
    # Single-sample regression task: drive the output to 0.05 for input 0.5.
    x = torch.tensor([[0.5]])
    y = torch.tensor([[0.05]])

    config = SOENConfig(
        num_input=1,
        num_hidden=10,
        num_output=1,
        p_input_hidden=1.0,
        p_hidden_output=1.0,
        p_input_input=1.0,
        p_hidden_hidden=1.0,
        p_output_output=1.0,
        allow_self_connections=False,
        allow_output_to_hidden_feedback=True,
        allow_hidden_to_input_feedback=True,
        allow_skip_connections=False,
        p_skip_connections=1.0,
        activation_function="NN_dendrite",
        dt=0.05,
        max_iter=1000,
        run_to_equilibrium=True,
        tol=1e-6,
        clip_phi=True,
        clip_state=True,
        test_noise_std=0.05,
        bias_flux_offsets=True,
        enforce_symmetric_weights=True,
        weight_init_method="glorot",
        # init_scale=0.5,
    )

    free_steps = 200
    nudged_steps = 200
    num_iterations = 100
    num_runs = 1
    # Base RNG seed: run k is seeded with SEED + k, so every run's weight
    # initialisation is reproducible yet differs between runs.
    SEED = 0

    results = []

    for run in range(num_runs):
        print(f"\nRun {run + 1}/{num_runs}")
        torch.manual_seed(SEED + run)
        np.random.seed(SEED + run)
        original_model = SOENModelEP(config)
        original_model.eval()

        # Every rule trains its own deep copy of the same initial model.
        run_results = {}
        for rule in original_model.rule_params.keys():
            print(f"  Training with {rule} learning rule")
            run_results[rule] = run_experiment(original_model, x, y, free_steps, nudged_steps, num_iterations, rule)

        results.append(run_results)

    # Output directory: set EP_RESULTS_DIR to override. The default is relative to the
    # working directory; the notebook's first cell changes into the project root, so
    # the default lands in the sibling Training_Apps/Results/ep directory.
    save_path = os.environ.get("EP_RESULTS_DIR", os.path.join("..", "Training_Apps", "Results", "ep"))
    os.makedirs(save_path, exist_ok=True)

    csv_filename = os.path.join(save_path, "mse_results.csv")
    save_results_to_csv(results, csv_filename)
    print(f"Results saved to {csv_filename}")
