# XOR problem

## Mount Google Drive

In [None]:
import os
import sys

from google.colab import drive
drive.mount('/content/drive')

# Make a symlink to the homework folder to save some clicks in the file browser.
DRIVE_PATH = '/content/drive/MyDrive/Stanford/Year2/TA/psych209/hw1/'
SYM_PATH = '/content/hw1'
# lexists (not exists) also sees a dangling symlink left over from an earlier
# session, so we do not try to create a link on top of it.
if not os.path.lexists(SYM_PATH):
    # Pure-Python equivalent of `ln -s DRIVE_PATH SYM_PATH`; raises on failure
    # instead of just printing a shell error.
    os.symlink(DRIVE_PATH, SYM_PATH)

# Make `import FFBP...` find the package under DRIVE_PATH/xor/.
# The membership check keeps re-runs of this cell from adding duplicates.
if DRIVE_PATH + 'xor/' not in sys.path:
    sys.path.append(DRIVE_PATH + 'xor/')
print(sys.version)


## Import

In [None]:
import sys
print(sys.version) # ensure that you're using python3
import torch
print("PyTorch version = {}".format(torch.__version__))
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim

import numpy as np
import random
import matplotlib.pyplot as plt
# %matplotlib inline

import pickle
import copy

## Dataset

In [None]:
class Dataset(object):
    """The four XOR patterns; the test set is the very same dict as the training set."""

    def __init__(self):
        xor_patterns = {"inputs": [[0, 0], [0, 1], [1, 0], [1, 1]],
                        "targets": [[0], [1], [1], [0]],
                        "names": ["p00", "p01", "p10", "p11"]}
        self.train_data = xor_patterns
        self.train_size = 4
        # test_data aliases train_data (same object, not a copy)
        self.test_data = self.train_data

    def reformat(self, x):
        """Turn a (nested) list of numbers into a float32 tensor with requires_grad=True."""
        return torch.tensor(x, dtype=torch.float32, requires_grad=True)

    def get_per_epoch_batches(self, batch_size):
        """Shuffle the training items and split them into (inputs, targets) batches.

        Every training item appears in exactly one batch. ``batch_size`` must
        divide ``train_size`` evenly (1, 2 or 4 for XOR).
        """
        assert self.train_size % batch_size == 0
        order = list(range(self.train_size))
        random.shuffle(order)
        all_inputs = self.train_data["inputs"]
        all_targets = self.train_data["targets"]
        return [
            (self.reformat([all_inputs[k] for k in order[start:start + batch_size]]),
             self.reformat([all_targets[k] for k in order[start:start + batch_size]]))
            for start in range(0, self.train_size, batch_size)
        ]

## Model

In [None]:
class Model(nn.Module):
    """Two-layer feed-forward network: input -> hidden -> output.

    The hidden layer uses a selectable nonlinearity ("sigmoid", "tanh" or
    "relu"); the output layer always uses a sigmoid. When ``forward`` is called
    with ``record_data=True`` it stores each layer's input, linear output (net)
    and activation. It also registers hooks so that a subsequent backward pass
    appends dE/dnet to ``self.gnet`` and dE/da to ``self.gact``.
    """

    # name -> activation module class accepted by select_nonlinearity
    _NONLINEARITIES = {"sigmoid": nn.Sigmoid, "tanh": nn.Tanh, "relu": nn.ReLU}

    # Initial (weight, bias) per linear layer from the original paper; only
    # valid for a 2-2-1 network.
    _SAVED_PARAMS = (
        ([[0.432171, 0.448781], [-0.038413, 0.036489]], [-0.27659, -0.4025]),
        ([[0.27208, 0.081714]], [0.2793]),
    )

    def __init__(self, input_size, hidden_size, num_classes,
                 hidden_nonlinearity="sigmoid", use_saved_params=False,
                 wrange=(-0.5, 0.5)):
        """Build the layers and initialize their parameters.

        Args:
            input_size, hidden_size, num_classes: layer widths.
            hidden_nonlinearity: "sigmoid", "tanh" or "relu".
            use_saved_params: copy in the fixed starting weights from the
                original paper (2-2-1 networks only) instead of random ones.
            wrange: (low, high) bounds of the uniform random initialization;
                ignored when ``use_saved_params`` is true.

        Raises:
            ValueError: unknown ``hidden_nonlinearity``, or
                ``use_saved_params`` with a network that is not 2-2-1.
        """
        super(Model, self).__init__()
        # specify network layers
        self.linear_layers = nn.ModuleList(
                                [nn.Linear(input_size, hidden_size),
                                 nn.Linear(hidden_size, num_classes)])
        self.nonlinearities = [self.select_nonlinearity(hidden_nonlinearity),
                               self.select_nonlinearity("sigmoid")]  # final nonlinearity is always sigmoid

        # initialize weights
        for i, layer in enumerate(self.linear_layers):
            self.init_weights(i, layer, use_saved_params, wrange)

    def record_gnet(self, grad):
        """Backward hook: store dE/dnet for one layer (called in backward order)."""
        self.gnet.append(grad)

    def record_gact(self, grad):
        """Backward hook: store dE/da for one layer (called in backward order)."""
        self.gact.append(grad)

    def select_nonlinearity(self, nonlinearity):
        """Return a new activation module for ``nonlinearity``.

        Raises:
            ValueError: if ``nonlinearity`` is not "sigmoid", "tanh" or "relu".
                An unknown name must not become ``None``, which would only
                fail later inside ``forward``.
        """
        try:
            module_cls = self._NONLINEARITIES[nonlinearity]
        except KeyError:
            raise ValueError("Unknown nonlinearity {!r}; expected one of {}".format(
                nonlinearity, sorted(self._NONLINEARITIES))) from None
        return module_cls()

    def init_weights(self, i, layer, use_saved_params, wrange):
        """Initialize the i-th linear layer in place.

        With ``use_saved_params`` the paper's starting weights are copied in.
        A shape mismatch raises ValueError instead of silently replacing the
        parameters with tensors of the wrong shape. Otherwise weights and
        biases are drawn uniformly from ``[wrange[0], wrange[1]]``.
        """
        if use_saved_params:
            saved_weight, saved_bias = self._SAVED_PARAMS[i]
            weight = torch.tensor(saved_weight)
            bias = torch.tensor(saved_bias)
            if layer.weight.shape != weight.shape or layer.bias.shape != bias.shape:
                raise ValueError(
                    "use_saved_params requires a 2-2-1 network; layer {} has weight "
                    "shape {} but the saved weights have shape {}".format(
                        i, tuple(layer.weight.shape), tuple(weight.shape)))
            with torch.no_grad():
                layer.weight.copy_(weight)
                layer.bias.copy_(bias)
        else:
            with torch.no_grad():
                layer.weight.uniform_(wrange[0], wrange[1])
                layer.bias.uniform_(wrange[0], wrange[1])

    def forward(self, inp, record_data=True):
        """Feed ``inp`` through every (linear, nonlinearity) pair and return the output.

        With ``record_data`` true, resets and fills ``layer_inputs``,
        ``layer_outputs`` (net) and ``layer_activations``, and registers the
        gradient hooks. The hooks fill ``gnet``/``gact`` in backward order,
        i.e. last layer first.
        """
        if record_data:
            # for visualization purposes; not necessary to train the model
            self.layer_inputs, self.layer_outputs, self.layer_activations = [], [], []
            self.gnet, self.gact = [], []

        out = inp
        for layer, nonlinearity in zip(self.linear_layers, self.nonlinearities):
            if record_data:
                self.layer_inputs.append(out)

            out = layer(out)  # net input of this layer
            if record_data:
                self.layer_outputs.append(out)
                out.register_hook(self.record_gnet)  # dE/dnet

            out = nonlinearity(out)  # activation of this layer
            if record_data:
                self.layer_activations.append(out)
                out.register_hook(self.record_gact)  # dE/da
        return out

## Trainer

In [None]:
# utils
def to_np(t):
    """Return tensor ``t`` as a NumPy array, detached from autograd (memory is shared)."""
    return t.detach().numpy()

def make_dir(path):
    """Create directory ``path`` and any missing parents; do nothing if it already exists.

    ``exist_ok`` avoids the race between checking for the directory and
    creating it. Raises FileExistsError if ``path`` exists but is not a
    directory.
    """
    os.makedirs(path, exist_ok=True)
        
def get_logdir():
    """Return the path of the next unused ``logdirs/logdir_NNN`` directory.

    Looks in ``logdirs`` (relative to the current working directory) for
    entries named exactly ``logdir_<digits>``. Returns one past the highest
    number, zero-padded to three digits. Other entries are ignored rather
    than crashing the int() parse. A missing ``logdirs`` directory counts as
    empty. The returned directory is not created here.
    """
    results_dir = "logdirs"
    prefix = "logdir_"
    try:
        entries = os.listdir(results_dir)
    except FileNotFoundError:
        entries = []
    nums = [int(entry[len(prefix):]) for entry in entries
            if entry.startswith(prefix) and entry[len(prefix):].isdecimal()]
    num = max(nums) + 1 if nums else 0
    return os.path.join(results_dir, prefix + str(num).zfill(3))

class Trainer(object):
    """Train a model on a dataset with SGD + momentum, logging losses and test snapshots.

    Results are pickled to ``<save_dir>/runlog_<name>.pkl``. Checkpoints go to
    ``<save_dir>/checkpoint_files_<name>/checkpoint.pth``. If that checkpoint
    already exists when the trainer is constructed, model and optimizer state
    are restored from it and training resumes after the checkpointed epoch.

    Args:
        dataset: provides ``train_size``, ``test_data``, ``reformat`` and
            ``get_per_epoch_batches`` (e.g. ``Dataset``).
        model: provides ``linear_layers``, ``forward(x, record_data)`` and the
            recorded ``layer_*`` / ``gnet`` / ``gact`` lists (e.g. ``Model``).
        name: run name used in the log and checkpoint paths.
        train_batch_size: must divide ``dataset.train_size``.
        num_training_epochs: maximum number of epochs.
        stopping_criterion: stop once the summed epoch loss is <= this value.
        learning_rate, momentum: SGD hyperparameters.
        save_freq, test_freq, checkpoint_freq, print_freq: epoch intervals for
            saving results, recording test snapshots, checkpointing and
            printing progress (``print_freq=0`` disables printing).
        show_plot: plot the loss curve when training finishes.
        save_dir: output directory; when None, the next ``logdirs/logdir_NNN``.
    """

    def __init__(self, dataset, model,
                 name="",
                 train_batch_size=4,
                 num_training_epochs=500,
                 stopping_criterion=0.04,
                 learning_rate=0.25, momentum=0.9,
                 save_freq=1, test_freq=10,
                 checkpoint_freq=200,
                 print_freq=0, show_plot=False,
                 save_dir=None):
        self.dataset = dataset
        self.model = model
        self.name = name

        # training parameters
        self.train_batch_size = train_batch_size
        assert self.dataset.train_size % self.train_batch_size == 0
        self.num_training_epochs = num_training_epochs
        self.stopping_criterion = stopping_criterion
        self.learning_rate = learning_rate
        self.momentum = momentum

        # loss function: sum of squared errors over all items and outputs
        self.criterion = lambda predictions, targets: torch.sum((targets - predictions) ** 2)

        self.optimizer = optim.SGD(model.parameters(),
                                   lr=self.learning_rate,
                                   momentum=self.momentum)

        # viewing & saving parameters (frequencies are in epochs)
        self.save_dir = save_dir
        self.print_freq = print_freq
        self.save_freq = save_freq
        self.test_freq = test_freq
        self.checkpoint_freq = checkpoint_freq
        self.show_plot = show_plot
        self.setup()

    def setup(self):
        """Resolve output paths, create directories, and resume from a checkpoint if one exists."""
        if self.save_dir is None:
            self.save_dir = get_logdir()
        self.checkpoints_dir = os.path.join(self.save_dir, "checkpoint_files_{}".format(self.name))
        for path in (self.save_dir, self.checkpoints_dir):
            make_dir(path)
        self.log_file = os.path.join(os.getcwd(), self.save_dir, "runlog_{}.pkl".format(self.name))
        # the run name is already part of checkpoints_dir
        self.checkpoint_file = os.path.join(self.checkpoints_dir, "checkpoint.pth")

        if os.path.exists(self.checkpoint_file):
            print("Loading from checkpoint: {}".format(self.checkpoint_file))
            checkpoint = torch.load(self.checkpoint_file)
            self.start_epoch = checkpoint["epoch"] + 1
            self.model.load_state_dict(checkpoint["model_state"])
            self.optimizer.load_state_dict(checkpoint["optimizer_state"])
            self.load_results()
            # Results are usually saved more often than checkpoints, so the log
            # can run past the checkpoint. Drop those epochs so they are not
            # recorded twice once training resumes from start_epoch.
            self._discard_results_from(self.start_epoch)
        else:
            self.start_epoch = 0
            self.loss_data = {"epochs": [], "train_losses": []}
            self.log = []

    def _discard_results_from(self, epoch):
        """Drop recorded losses and test snapshots for epochs >= ``epoch``."""
        kept = [(e, loss) for e, loss in zip(self.loss_data["epochs"],
                                             self.loss_data["train_losses"])
                if e < epoch]
        self.loss_data["epochs"] = [e for e, _ in kept]
        self.loss_data["train_losses"] = [loss for _, loss in kept]
        self.log = [info for info in self.log if info["enum"] < epoch]

    def save_pickle(self, d, fname):
        """Pickle ``d`` to the file ``fname``."""
        with open(fname, 'wb') as f:
            pickle.dump(d, f)

    def load_pickle(self, fname):
        """Unpickle and return the object in ``fname``.

        Only use this on files this trainer wrote itself; pickle is unsafe on
        untrusted data.
        """
        with open(fname, 'rb') as f:
            d = pickle.load(f)
        return d

    def save_results(self):
        """Write the test snapshots and loss curve to ``self.log_file``."""
        info = {"test_data": self.log,
                "loss_data": self.loss_data}
        self.save_pickle(info, self.log_file)

    def load_results(self):
        """Restore ``self.log`` and ``self.loss_data`` from ``self.log_file``."""
        f = self.load_pickle(self.log_file)
        self.log = f["test_data"]
        self.loss_data = f["loss_data"]

    def save_checkpoint(self, epoch):
        """Save model and optimizer state, tagged with the epoch just completed."""
        checkpoint = {"epoch": epoch,
                      "model_state": self.model.state_dict(),
                      "optimizer_state": self.optimizer.state_dict()}
        torch.save(checkpoint, self.checkpoint_file)

    def final_save_and_view(self, epoch, loss):
        """Snapshot the test set, report, optionally plot, then save results and checkpoint.

        ``loss`` is accepted for interface compatibility but not used.
        """
        self.eval_test_points(epoch)
        self.print_final()
        if self.show_plot:
            self.plot_loss()
        self.save_results()
        self.save_checkpoint(epoch)

    def init_info_dict(self, epoch, layer_names):
        """Return an empty per-epoch snapshot dict with one sub-dict per linear layer."""
        info = {"enum": epoch,
                "input": to_np(self.dataset.reformat(
                                    self.dataset.test_data["inputs"])),
                "target": to_np(self.dataset.reformat(
                                    self.dataset.test_data["targets"])),
                "labels": self.dataset.test_data["names"],
                "loss_sum": 0.,  # summed loss over test items
                "loss": np.zeros(len(self.dataset.test_data["inputs"]))}  # loss of each test item

        for i in range(len(self.model.linear_layers)):
            layer_info = {"input_": None,  # layer input (n_items x layer_inp_sz), from model.layer_inputs
                          "weights": copy.deepcopy(to_np(self.model.linear_layers[i].weight)),
                          "biases": copy.deepcopy(to_np(self.model.linear_layers[i].bias)),
                          "net": None,  # linear output before nonlinearity (n_items x layer_output_sz), from model.layer_outputs
                          "act": None,  # layer activations (n_items x layer_output_sz), from model.layer_activations
                          "gweights": None,  # dE/dW per item (n_items x weight_sz[0] x weight_sz[1])
                          "gbiases": None,  # dE/dB per item (n_items x bias_sz)
                          "gnet": None,  # dE/dnet (n_items x layer_output_sz)
                          "gact": None,  # dE/da (n_items x layer_output_sz)
                          "sgweights": None,  # gweights summed across test items (weight_sz[0] x weight_sz[1])
                          "sgbiases": None}  # gbiases summed across test items (bias_sz)
            info[layer_names[i]] = layer_info
        return info

    def eval_test_points(self, epoch, print_info=False):
        """Record per-item activations and gradients on the test set and append them to ``self.log``.

        Each test item is fed through the model on its own (a batch of one) so
        that gradients can be attributed to individual items. ``backward`` is
        called for visualization only. The optimizer never steps here, and
        the gradients are cleared again afterwards.
        """
        layer_names = ["layer{}".format(i) for i in range(len(self.model.linear_layers))]
        info = self.init_info_dict(epoch, layer_names)

        # Normally the whole test set would be fed through as a single batch.
        for i in range(len(self.dataset.test_data["inputs"])):
            test_input = self.dataset.reformat(
                            [self.dataset.test_data["inputs"][i]])
            test_target = self.dataset.reformat(
                            [self.dataset.test_data["targets"][i]])

            self.optimizer.zero_grad()
            test_prediction = self.model.forward(test_input, record_data=True)
            test_loss = self.criterion(test_prediction, test_target)

            # Gradients for visualization only. Never step the optimizer on
            # test data: that would be training on the test set.
            test_loss.backward()

            info["loss"][i] = to_np(test_loss)
            info["loss_sum"] += to_np(test_loss)

            # hooks fire in backward order; reverse so index j matches layer j
            self.model.gnet.reverse()
            self.model.gact.reverse()

            for j, layer_name in enumerate(layer_names):
                layer = self.model.linear_layers[j]
                # (key, tensor, add a leading item axis before stacking?)
                recorded = [("input_", self.model.layer_inputs[j], False),
                            ("net", self.model.layer_outputs[j], False),
                            ("act", self.model.layer_activations[j], False),
                            ("gweights", layer.weight.grad, True),
                            ("gbiases", layer.bias.grad, True),
                            ("gnet", self.model.gnet[j], False),
                            ("gact", self.model.gact[j], False)]
                for key, value, expand in recorded:
                    value = copy.deepcopy(to_np(value))
                    if expand:
                        value = np.expand_dims(value, 0)
                    if i == 0:
                        info[layer_name][key] = value
                    else:
                        info[layer_name][key] = np.vstack((info[layer_name][key], value))

        # leave no evaluation gradients behind in the parameters
        self.optimizer.zero_grad()

        for layer_name in layer_names:
            info[layer_name]["sgweights"] = np.sum(info[layer_name]["gweights"], 0)
            info[layer_name]["sgbiases"] = np.sum(info[layer_name]["gbiases"], 0)

        if print_info:
            print("\n------------------------")
            self.print_dict(info)
        self.log.append(info)

    def print_dict(self, d):
        """Recursively print a (nested) dict, with array shapes after ndarray values."""
        for k, v in d.items():
            if isinstance(v, dict):
                print("\n{}".format(k))
                self.print_dict(v)
            else:
                print("\n{}".format(k))
                print(v)
                if isinstance(v, np.ndarray):
                    print(v.shape)

    def print_progress(self, epoch, loss):
        """Print one progress line for ``epoch``."""
        print('Epoch {}: {}'.format(epoch, loss))

    def print_final(self):
        """Print the first and last recorded epochs and their losses."""
        out = "Run {}:".format(self.name)
        out += " Loss at epoch {}:".format(self.loss_data["epochs"][0])
        out += " {0:.4f};".format(self.loss_data["train_losses"][0])
        out += " Last epoch: {};".format(self.loss_data["epochs"][-1])
        out += " Loss on last epoch: {0:.4f}".format(self.loss_data["train_losses"][-1])
        print(out)

    def plot_loss(self):
        """Plot the recorded training loss against epoch."""
        fig, ax = plt.subplots()
        ax.plot(self.loss_data["epochs"], self.loss_data["train_losses"])
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
        max_loss = max(self.loss_data["train_losses"])
        ax.set_xlim([0, self.num_training_epochs])
        ax.set_ylim([0, max_loss + 0.1])
        ax.set_title('Training Progress')

    def train(self):
        """Run the training loop starting at ``self.start_epoch``.

        Each epoch does the following:
        - Every ``test_freq`` epochs, snapshot the test set.
        - Take one SGD step per batch.
        - Record the summed batch loss.
        - Periodically print, save results and checkpoint.

        Training stops early once the epoch loss is <= ``stopping_criterion``.
        The final epoch, whether early or not, goes through
        ``final_save_and_view``.
        """
        if self.start_epoch >= self.num_training_epochs:
            # resumed from a checkpoint of the last epoch: nothing left to train
            return

        for epoch in range(self.start_epoch, self.num_training_epochs):
            # visualization purposes only
            if (epoch % self.test_freq) == 0:
                self.eval_test_points(epoch)

            # each training item is seen exactly once per epoch
            train_batches = self.dataset.get_per_epoch_batches(self.train_batch_size)

            epoch_loss = 0.0
            for inputs, targets in train_batches:
                self.optimizer.zero_grad()
                predictions = self.model.forward(inputs, record_data=False)
                loss = self.criterion(predictions, targets)
                loss.backward()
                self.optimizer.step()
                # .item() gives a plain float, so each batch's autograd graph can be freed
                epoch_loss += loss.item()

            self.loss_data["epochs"].append(epoch)
            self.loss_data["train_losses"].append(float(epoch_loss))

            if self.print_freq > 0 and (epoch % self.print_freq) == 0:
                self.print_progress(epoch, epoch_loss)

            if (epoch % self.save_freq) == 0:
                self.save_results()

            if (epoch % self.checkpoint_freq) == 0:
                self.save_checkpoint(epoch)

            if epoch_loss <= self.stopping_criterion:
                self.final_save_and_view(epoch, epoch_loss)
                return

            if epoch == self.num_training_epochs - 1:
                self.final_save_and_view(epoch, epoch_loss)
                
def train_multiple_runs(num_runs,
                        hidden_size,
                        hidden_nonlinearity,
                        use_saved_params,
                        weight_initialization_range,
                        num_training_epochs,
                        learning_rate, momentum,
                        stopping_criterion,
                        train_batch_size):
    """Train ``num_runs`` freshly initialized models on XOR and collect their loss curves.

    All runs share one new ``logdirs/logdir_NNN`` directory (printed at the
    start); run r is logged under the name ``str(r)``.

    Returns:
        dict mapping run index -> {"epochs": [...], "train_losses": [...]}.
    """
    make_dir("logdirs")
    save_dir = get_logdir()
    print('Results for this runset saved in: {}'.format(save_dir))
    xor_data = Dataset()

    curves = {}
    for run in range(num_runs):
        net = Model(input_size=2, hidden_size=hidden_size, num_classes=1,
                    hidden_nonlinearity=hidden_nonlinearity,
                    use_saved_params=use_saved_params,
                    wrange=weight_initialization_range)
        run_trainer = Trainer(xor_data, net,
                              name=str(run),
                              train_batch_size=train_batch_size,
                              num_training_epochs=num_training_epochs,
                              stopping_criterion=stopping_criterion,
                              learning_rate=learning_rate,
                              momentum=momentum,
                              save_dir=save_dir)
        run_trainer.train()
        curves[run] = {"epochs": run_trainer.loss_data["epochs"],
                       "train_losses": run_trainer.loss_data["train_losses"]}
    return curves

## Visualization

In [None]:
def compute_mean_loss_across_runs(results, num_runs, num_training_epochs):
    """Return the per-epoch training loss averaged over runs 0..num_runs-1.

    A run that stopped early is padded with its final loss up to
    ``num_training_epochs`` entries before averaging. The padding length
    assumes the run's epochs are numbered 0, 1, 2, ... without gaps.
    """
    total = np.zeros(num_training_epochs)
    for run in range(num_runs):
        epochs = results[run]["epochs"]
        losses = results[run]["train_losses"]
        if len(losses) < num_training_epochs:
            # early stopping: repeat the last loss for the epochs that never ran
            n_missing = num_training_epochs - (epochs[-1] + 1)
            losses = np.concatenate((losses, np.full(n_missing, losses[-1], dtype=float)))
        total += losses
    return total / num_runs

def plot_results_by_run(num_runs, num_training_epochs,
                        results, mean_training_loss_across_runs):
    """Plot each run's training-loss curve plus the dashed mean curve, then show the figure.

    Args:
        num_runs: number of runs to plot; they are ``results[0..num_runs-1]``.
        num_training_epochs: length of ``mean_training_loss_across_runs``.
        results: dict run index -> {"epochs": [...], "train_losses": [...]}.
        mean_training_loss_across_runs: per-epoch mean loss, drawn dashed.

    Raises:
        ValueError: if ``num_runs`` < 1. There is then no run to plot or to
            scale the y-axis by, and the y-limit would otherwise be undefined.
    """
    if num_runs < 1:
        raise ValueError("num_runs must be at least 1, got {}".format(num_runs))

    fig, ax = plt.subplots(figsize=[12, 8])  # width and height of plot - change to suit your preference
    for r in range(num_runs):
        ax.plot(results[r]["epochs"], results[r]["train_losses"],
                label="Run {}".format(r))

    # add mean across runs
    ax.plot(range(num_training_epochs), mean_training_loss_across_runs,
            '--', label="mean")

    # the y-axis spans 0 up to the largest loss recorded in any run
    max_loss_across_runs = max(max(results[r]["train_losses"]) for r in range(num_runs))
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_ylim([0., max_loss_across_runs])
    ax.legend()
    ax.set_title('Results by Run')
    plt.show()

## Run

In [None]:
# run a single runset

def main():
    """Train one runset of XOR networks and plot the per-run and mean loss curves."""
    num_runs = 2  # 1 for Ex. 1; 8 for Ex 2.1; 10 for Ex. 2.2

    # Training budget - don't change these.
    num_training_epochs = 500
    stopping_criterion = 0.04

    # Optimizer settings (defaults: learning rate 0.25, momentum 0.9).
    learning_rate, momentum = 0.25, 0.9

    # Initialization / batching. Set use_saved_params to True for Exercise 1,
    # then False for Exercise 2. The [-0.5, 0.5] range is only used to
    # initialize the weights when use_saved_params is False.
    use_saved_params = True
    weight_initialization_range = None if use_saved_params else [-0.5, 0.5]
    train_batch_size = 4  # default: 4

    # Model settings.
    hidden_size = 2  # default: 2
    hidden_nonlinearity = "sigmoid"  # default "sigmoid"; other options: "relu" or "tanh"

    # TRAIN (1 RUNSET)
    results_by_run = train_multiple_runs(
        num_runs, hidden_size, hidden_nonlinearity, use_saved_params,
        weight_initialization_range, num_training_epochs, learning_rate,
        momentum, stopping_criterion, train_batch_size)

    # VISUALIZE
    mean_training_loss_across_runs = compute_mean_loss_across_runs(
        results_by_run, num_runs, num_training_epochs)
    plot_results_by_run(num_runs, num_training_epochs,
                        results_by_run, mean_training_loss_across_runs)


main()

## Visualize XOR

In [None]:
from FFBP.vis_utils import view_layers, view_layers_colab
import numpy as np

%matplotlib inline

def logistic(net):
    """Return the logistic sigmoid 1 / (1 + exp(-net)), elementwise.

    Evaluated in a numerically stable form: ``exp`` is only taken of
    non-positive numbers, so large-magnitude inputs (e.g. -1000) do not raise
    an overflow RuntimeWarning. Scalar input returns a NumPy scalar;
    array-like input (arrays, lists) returns an ndarray of the same shape.
    """
    net = np.asarray(net)
    e = np.exp(-np.abs(net))  # always in (0, 1], never overflows
    out = np.where(net >= 0, 1.0 / (1.0 + e), e / (1.0 + e))
    return out[()]  # unwrap 0-d results to a scalar; arrays pass through unchanged

In [None]:
# Edit the digits to visualize the results stored in a different logdir.
logdir = 'logdirs/logdir_000'

# Arguments are passed straight through to FFBP.vis_utils.view_layers_colab;
# see that module for what each `mode` value displays.
view_layers_colab(logdir=logdir, mode=2, show_values=True)