# W2D4 Tutorial1

## Install dependencies


In [None]:
# @title Install dependencies
# !pip install numpy matplotlib Pillow torch torchvision transformers ipywidgets gradio trdg scikit-learn networkx

## Import Dependencies


In [None]:
# @title Import Dependencies

# Standard Libraries for basic operations and file handling
import os
import random
import hashlib
import requests

# Core Data Science Libraries for numerical computations and data manipulation
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# Image Processing Libraries for handling and manipulating image data
from PIL import Image
import matplotlib.pyplot as plt
from torchvision import transforms

# Deep Learning Libraries for model building, training, and evaluation
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

# Utility Libraries for creating interactive elements and interfaces
import ipywidgets as widgets
import gradio as gr
from IPython.display import IFrame

# Libraries for Data Generation and Graph Analysis
from trdg.generators import GeneratorFromStrings
import networkx as nx

In [None]:
# A feedforward network is organised into layers. Each layer contains a fixed number of neurons,
# and every neuron in one layer is connected to every neuron in the next layer.

# We simulate neuron activation by activating one randomly chosen input neuron and then
# propagating that activation forward, mimicking the feedforward process.

class SimpleFNNModel:
    """Toy fully connected feed-forward network for visualising activation flow.

    ``activations[i][j]`` is True when neuron ``j`` of layer ``i`` is active.
    Activation starts at one randomly chosen input neuron and is pushed
    forward layer by layer. Because every neuron feeds every neuron in the
    next layer, any active neuron in layer ``i - 1`` activates all of layer
    ``i``.

    Args:
        layer_sizes: number of neurons per layer, input layer first. The
            sequence is copied into a new list, so neither the caller's object
            nor the default value is shared between instances.
    """

    def __init__(self, layer_sizes=(5, 3, 2)):
        # Copy (instead of storing the argument itself) so that mutating
        # self.layer_sizes never leaks into the default or the caller's list.
        self.layer_sizes = list(layer_sizes)
        self.activations = [[False] * size for size in self.layer_sizes]

    def activate_neurons(self):
        """Activate one random input neuron and propagate activation forward.

        Activations accumulate across calls until ``reset_activations`` is
        called. Does nothing for a network without layers.
        """
        if not self.layer_sizes:
            return
        self.activations[0][np.random.randint(0, self.layer_sizes[0])] = True

        for i in range(1, len(self.layer_sizes)):
            # Fully connected: a single active upstream neuron lights the whole layer.
            if any(self.activations[i - 1]):
                self.activations[i][:] = [True] * self.layer_sizes[i]

    def reset_activations(self):
        """Mark every neuron in every layer as inactive."""
        self.activations = [[False] * size for size in self.layer_sizes]

    def draw_network(self):
        """Plot neurons (green = active, red = inactive) and all connections."""
        plt.figure(figsize=(10, 6))
        prev_y = None  # vertical positions of the previous layer's neurons
        for layer_index, layer_size in enumerate(self.layer_sizes):
            x = np.full(layer_size, layer_index)
            y = np.linspace(0, 1, layer_size)
            colors = ['green' if act else 'red' for act in self.activations[layer_index]]

            # Neurons drawn above the connection lines so their colour stays visible.
            plt.scatter(x, y, c=colors, s=100, label=f'Layer {layer_index}', zorder=2)

            # Connect every neuron of the previous layer to every neuron of this one.
            if prev_y is not None:
                for y_prev in prev_y:
                    for y_curr in y:
                        plt.plot([layer_index - 1, layer_index], [y_prev, y_curr], 'gray', zorder=1)
            prev_y = y

        plt.legend()
        plt.axis('off')
        plt.show()

# Create FNN model instance
fnn_model = SimpleFNNModel()

# Define UI elements
activate_button = widgets.Button(description='Activate Neurons')
reset_button = widgets.Button(description='Reset')
output_area = widgets.Output()


# Button callbacks. The output pane and model are bound as default arguments
# (evaluated once, right here) rather than looked up as globals at click time:
# later cells rebind `output_area`, which would otherwise make these buttons
# draw into another cell's widget.
def on_activate_clicked(b, out=output_area, model=fnn_model):
    """Activate a random input neuron, propagate it, and redraw the network."""
    with out:
        out.clear_output(wait=True)
        model.activate_neurons()
        model.draw_network()


def on_reset_clicked(b, out=output_area, model=fnn_model):
    """Clear all activations and redraw the network."""
    with out:
        out.clear_output(wait=True)
        model.reset_activations()
        model.draw_network()


# Set up button callbacks
activate_button.on_click(on_activate_clicked)
reset_button.on_click(on_reset_clicked)

# Display UI
display(widgets.VBox([activate_button, reset_button, output_area]))

In [None]:
class SimpleGNWModel:
    """Toy Global Neuronal Workspace: activating a node also activates its neighbours.

    Args:
        num_nodes: number of nodes in the random Erdos-Renyi graph.
        p: edge probability of the Erdos-Renyi graph.
        seed: optional seed for graph generation and for the drawing layout.
    """

    def __init__(self, num_nodes=5, p=0.5, seed=None):
        self.num_nodes = num_nodes
        self.network = nx.erdos_renyi_graph(n=num_nodes, p=p, seed=seed)
        self.activations = {node: False for node in self.network.nodes}
        # Node positions computed once: without an explicit `pos`, nx.draw runs
        # a fresh (random) spring layout on every redraw and the nodes jump
        # around between button clicks.
        self.pos = nx.spring_layout(self.network, seed=seed)

    def activate_node(self):
        """Activate a random node and broadcast the activation to its neighbours.

        Does nothing when the graph has no nodes.
        """
        nodes = list(self.network.nodes)
        if not nodes:
            return
        selected_node = random.choice(nodes)
        self.activations[selected_node] = True

        # Simulate global broadcast to directly connected nodes.
        for neighbor in self.network.neighbors(selected_node):
            self.activations[neighbor] = True

    def reset_activations(self):
        """Mark every node as inactive."""
        self.activations = {node: False for node in self.network.nodes}

    def draw_network(self):
        """Draw the graph with active nodes green and inactive nodes red."""
        color_map = ['green' if self.activations[node] else 'red' for node in self.network.nodes]
        nx.draw(self.network, pos=self.pos, node_color=color_map, with_labels=True, node_size=700)
        plt.show()

# Create a GNW model instance
gnw_model = SimpleGNWModel()

# Button to activate a node
activate_button = widgets.Button(description='Activate Node')

# Button to reset activations
reset_button = widgets.Button(description='Reset')

# Output area for the network graph
output_area = widgets.Output()


# The output pane and model are bound as default arguments (evaluated here, at
# definition time) so re-running another cell that rebinds `output_area` or
# the model names cannot redirect these callbacks.
def on_activate_clicked(b, out=output_area, model=gnw_model):
    """Activate a random node plus its neighbours, then redraw the graph."""
    with out:
        out.clear_output(wait=True)
        model.activate_node()
        model.draw_network()


def on_reset_clicked(b, out=output_area, model=gnw_model):
    """Clear all node activations, then redraw the graph."""
    with out:
        out.clear_output(wait=True)
        model.reset_activations()
        model.draw_network()


activate_button.on_click(on_activate_clicked)
reset_button.on_click(on_reset_clicked)

display(widgets.VBox([activate_button, reset_button, output_area]))

In [None]:
# Set fixed random seeds so data noise and weight initialisation are reproducible
torch.manual_seed(0)
np.random.seed(0)

# Generate a synthetic classification dataset with 5000 samples, 20 features, 2 classes, and 17 informative features
X, y = make_classification(n_samples=5000, n_features=20, n_classes=2, n_informative=17, random_state=42)
# Split the dataset into training and testing sets with 80% training and 20% testing
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

noise_factor = 0  # Standard deviation of the Gaussian noise added to the training data (0 disables it)
noise = np.random.randn(*X_train.shape) * noise_factor  # Generate Gaussian noise
X_train = X_train + noise  # Add noise to the training data

# Standardise features using statistics estimated on the training split only.
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
# The test split must reuse the training mean/std (transform, NOT fit_transform);
# refitting on test data leaks test statistics into preprocessing and scales the
# two splits inconsistently.
X_test = scaler.transform(X_test)

# Convert arrays into PyTorch tensors for training and testing sets
X_train_torch = torch.FloatTensor(X_train)
X_test_torch = torch.FloatTensor(X_test)
y_train_torch = torch.LongTensor(y_train)
y_test_torch = torch.LongTensor(y_test)

# Create TensorDataset from tensors
train_dataset = TensorDataset(X_train_torch, y_train_torch)
test_dataset = TensorDataset(X_test_torch, y_test_torch)

# Create DataLoader for both training and testing datasets to iterate over batches
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

class FeedforwardNet(nn.Module):
    """Baseline MLP: 20 input features -> 64 ReLU hidden units -> 2 class logits."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(20, 64)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(64, 2)

    def forward(self, x):
        """Return unnormalised class scores for a batch of 20-feature inputs."""
        return self.fc2(self.relu(self.fc1(x)))

class GNWNet(nn.Module):
    """MLP whose hidden units are scaled by a learned sigmoid gate ("attention").

    The hidden pre-activations ``h = fc1(x)`` are multiplied element-wise by
    ``sigmoid(attention(h))`` before the ReLU, so each of the 64 units is
    weighted by an input-dependent factor in (0, 1).
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(20, 64)
        self.attention = nn.Linear(64, 64)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(64, 2)

    def forward(self, x):
        """Return 2 class logits for a batch of 20-feature inputs."""
        hidden = self.fc1(x)
        gate = torch.sigmoid(self.attention(hidden))
        return self.fc2(self.relu(hidden * gate))

def train_model(model, train_loader, criterion, optimizer):
    """Run one training epoch and return the mean loss per sample.

    Args:
        model: network to train; switched to training mode.
        train_loader: iterable of ``(data, target)`` batches.
        criterion: loss function. Assumed to use ``reduction='mean'`` (the
            ``nn.CrossEntropyLoss`` default) so each batch loss is a per-sample
            mean -- TODO confirm if another criterion is plugged in.
        optimizer: optimizer over ``model``'s parameters.

    Returns:
        Loss averaged over all samples of the epoch. Batch losses are weighted
        by batch size so a smaller final batch is not over-represented.

    Raises:
        ValueError: if ``train_loader`` yields no samples.
    """
    model.train()
    total_loss = 0.0
    num_samples = 0
    for data, target in train_loader:
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        batch_size = target.size(0)
        total_loss += loss.item() * batch_size
        num_samples += batch_size
    if num_samples == 0:
        raise ValueError("train_loader yielded no samples")
    return total_loss / num_samples

def evaluate_model(model, test_loader, criterion):
    """Evaluate ``model`` without gradients and return ``(mean loss, accuracy %)``.

    Args:
        model: network to evaluate; switched to evaluation mode.
        test_loader: iterable of ``(data, target)`` batches with integer class targets.
        criterion: loss function, assumed ``reduction='mean'`` -- TODO confirm
            if another criterion is used.

    Returns:
        average_loss: loss averaged over all samples (batch losses weighted by
            batch size, so a short final batch is not over-represented).
        accuracy: percentage of samples whose arg-max logit equals the target.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    num_samples = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            batch_size = target.size(0)
            total_loss += criterion(output, target).item() * batch_size
            correct += (output.argmax(dim=1) == target).sum().item()
            num_samples += batch_size
    if num_samples == 0:
        raise ValueError("test_loader yielded no samples")
    average_loss = total_loss / num_samples
    accuracy = 100. * correct / num_samples
    return average_loss, accuracy

# Per-epoch histories filled in by the training runs below
losses_basic_train = []
losses_gnw_train = []
losses_basic_val = []
losses_gnw_val = []
accuracy_basic_net = []
accuracy_gnw_net = []

# Models, shared loss, and one Adam optimizer (lr = 1e-5) per model
basic_net = FeedforwardNet()
gnw_net = GNWNet()
criterion = nn.CrossEntropyLoss()
optimizer_basic = optim.Adam(basic_net.parameters(), lr=0.00001)
optimizer_gnw = optim.Adam(gnw_net.parameters(), lr=0.00001)

patience = 10  # Consecutive epochs without a lower validation loss before stopping


def _fit_with_early_stopping(model, optimizer, max_epochs, model_name,
                             train_history, val_history, accuracy_history):
    """Train ``model`` for at most ``max_epochs`` epochs with early stopping.

    After every epoch the training loss, validation loss and validation
    accuracy are appended to the given history lists. Training halts once the
    validation loss has failed to improve for ``patience`` consecutive epochs.

    Returns:
        (best validation loss seen, number of epochs since the last improvement)
    """
    best_val_loss = float('inf')
    stale_epochs = 0
    for epoch in range(max_epochs):
        train_loss = train_model(model, train_loader, criterion, optimizer)
        val_loss, accuracy = evaluate_model(model, test_loader, criterion)
        train_history.append(train_loss)
        val_history.append(val_loss)
        accuracy_history.append(accuracy)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            stale_epochs = 0
        else:
            stale_epochs += 1

        if stale_epochs >= patience:
            print(f"Early stopping triggered for {model_name} at epoch {epoch}")
            break
    return best_val_loss, stale_epochs


# NOTE(review): the "validation" loader is the held-out test split, so early
# stopping selects on test data; the two models also get different epoch
# budgets (3000 vs 1000). Both are kept exactly as before.
best_val_loss_basic, epochs_no_improve_basic = _fit_with_early_stopping(
    basic_net, optimizer_basic, 3000, "FeedforwardNet",
    losses_basic_train, losses_basic_val, accuracy_basic_net)

best_val_loss_gnw, epochs_no_improve_gnw = _fit_with_early_stopping(
    gnw_net, optimizer_gnw, 1000, "GNWNet",
    losses_gnw_train, losses_gnw_val, accuracy_gnw_net)

# One figure per model comparing its training and validation loss per epoch
for train_curve, val_curve, title in (
    (losses_basic_train, losses_basic_val, 'FeedforwardNetwork Training vs Validation Loss'),
    (losses_gnw_train, losses_gnw_val, 'GNWNetwork Training vs Validation Loss'),
):
    plt.figure(figsize=(10, 6))
    plt.plot(train_curve, label='Training Loss')
    plt.plot(val_curve, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title(title)
    plt.legend()
    plt.show()

In [None]:
import torch
import torch.nn as nn
import math
import numpy as np
import torch.multiprocessing as mp

class blocked_grad(torch.autograd.Function):
    """Identity in the forward pass; masked gradient in the backward pass.

    ``forward(x, mask)`` returns ``x`` unchanged. ``backward`` multiplies the
    incoming gradient by ``mask`` (so positions where ``mask`` is 0 receive no
    gradient) and returns an all-zero gradient for ``mask`` itself.
    """

    @staticmethod
    def forward(ctx, x, mask):
        # Only the mask is needed to compute gradients.
        ctx.save_for_backward(mask)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        (mask,) = ctx.saved_tensors
        grad_x = grad_output * mask
        grad_mask = mask * 0.0
        return grad_x, grad_mask

class GroupLinearLayer(nn.Module):
    """A bank of ``num_blocks`` independent, bias-free linear maps.

    Input is indexed as (batch, num_blocks, din). Slice ``i`` along the block
    axis is multiplied by its own weight ``w[i]`` of shape (din, dout), giving
    an output of shape (batch, num_blocks, dout).
    """

    def __init__(self, din, dout, num_blocks):
        super(GroupLinearLayer, self).__init__()
        # Small random init; owners may re-initialise (e.g. GroupLSTMCell does).
        self.w = nn.Parameter(0.01 * torch.randn(num_blocks, din, dout))

    def forward(self, x):
        per_block = x.transpose(0, 1)             # (num_blocks, batch, din)
        projected = torch.bmm(per_block, self.w)  # (num_blocks, batch, dout)
        return projected.transpose(0, 1)          # (batch, num_blocks, dout)

class GroupLSTMCell(nn.Module):
    """Compute ``num_lstms`` independent LSTM cells in one batched operation.

    Each cell has its own input-to-hidden and hidden-to-hidden weights, stored
    in GroupLinearLayer banks (which have no bias terms). All parameters are
    re-initialised uniformly in [-1/sqrt(hidden_size), 1/sqrt(hidden_size)].
    """

    def __init__(self, inp_size, hidden_size, num_lstms):
        super().__init__()
        self.inp_size = inp_size
        self.hidden_size = hidden_size

        # Four gate pre-activations (i, f, o, g) per cell.
        self.i2h = GroupLinearLayer(inp_size, 4 * hidden_size, num_lstms)
        self.h2h = GroupLinearLayer(hidden_size, 4 * hidden_size, num_lstms)
        self.reset_parameters()

    def reset_parameters(self):
        bound = 1.0 / math.sqrt(self.hidden_size)
        for param in self.parameters():
            param.data.uniform_(-bound, bound)

    def forward(self, x, hid_state):
        """Advance every cell by one step.

        input:  x (batch_size, num_lstms, input_size)
                hid_state: tuple (h, c), each (batch_size, num_lstms, hidden_size)
        output: h_t, c_t, each (batch_size, num_lstms, hidden_size)
        """
        h, c = hid_state
        size = self.hidden_size
        preact = self.i2h(x) + self.h2h(h)

        # First 3*size columns are the sigmoid gates (input, forget, output),
        # the last size columns the tanh candidate.
        i_t, f_t, o_t = preact[:, :, :3 * size].sigmoid().split(size, dim=2)
        g_t = preact[:, :, 3 * size:].tanh()

        c_t = c * f_t + i_t * g_t
        h_t = o_t * c_t.tanh()
        return h_t, c_t

class GNWCell(nn.Module):
    """Recurrent cell of ``num_units`` competing LSTM units plus a shared workspace.

    One call to ``forward`` performs one time step:
      1. Input attention: every unit attends over the real input and an
         appended all-zero "null" input. The ``k`` units that score the real
         input highest become active.
      2. Global workspace: units whose broadcast gate exceeds 0.5 contribute
         the mean of their hidden states to a workspace vector that is then
         given to every unit.
      3. Each unit's LSTM receives ``[attended input, workspace vector]``.
      4. Units exchange information through multi-head communication
         attention. Inactive units keep their previous hidden/cell state, and
         gradients into inactive units are blocked.

    Notes:
        ``comm_value_size`` is always forced to ``hidden_size`` because the
        communication output is added residually to the hidden state.
        ``rnn_cell`` is stored but the recurrent core is always an LSTM.
        Query and key head sizes must match (they are dot-producted).

    Raises:
        ValueError: if ``input_query_size != input_key_size`` or
            ``comm_query_size != comm_key_size``.
    """

    def __init__(self,
                 device, input_size, hidden_size, num_units, k, rnn_cell, input_key_size=64, input_value_size=400, input_query_size=64,
                 num_input_heads=1, input_dropout=0.1, comm_key_size=32, comm_value_size=100, comm_query_size=32, num_comm_heads=4, comm_dropout=0.1):
        super().__init__()
        if comm_value_size != hidden_size:
            # Residual connection in communication_attention requires this.
            comm_value_size = hidden_size
        if input_query_size != input_key_size:
            raise ValueError("input_query_size must equal input_key_size for input attention")
        if comm_query_size != comm_key_size:
            raise ValueError("comm_query_size must equal comm_key_size for communication attention")
        self.device = device
        self.hidden_size = hidden_size
        self.num_units = num_units
        self.rnn_cell = rnn_cell
        self.key_size = input_key_size
        self.k = k
        self.num_input_heads = num_input_heads
        self.num_comm_heads = num_comm_heads
        self.input_key_size = input_key_size
        self.input_query_size = input_query_size
        self.input_value_size = input_value_size

        self.comm_key_size = comm_key_size
        self.comm_query_size = comm_query_size
        self.comm_value_size = comm_value_size

        # Input attention: keys/values from the input, queries from each unit's hidden state.
        self.key = nn.Linear(input_size, num_input_heads * input_key_size).to(self.device)
        self.value = nn.Linear(input_size, num_input_heads * input_value_size).to(self.device)
        # Workspace vector, recomputed every step. Registered as a non-persistent
        # buffer (not a Parameter) so it can be reassigned with a plain tensor in
        # update_global_workspace and still follows .to(device).
        self.register_buffer('global_workspace', torch.zeros(1, 1, self.hidden_size), persistent=False)
        # One scalar gate per unit deciding whether that unit broadcasts.
        self.broadcast_gate = nn.Linear(self.hidden_size, 1)
        # Each LSTM unit consumes [attended input, workspace vector].
        self.rnn = GroupLSTMCell(input_value_size + hidden_size, hidden_size, num_units)
        self.query = GroupLinearLayer(hidden_size, input_query_size * num_input_heads, self.num_units)
        # Communication attention projections (one set of weights per unit).
        self.query_ = GroupLinearLayer(hidden_size, comm_query_size * num_comm_heads, self.num_units)
        self.key_ = GroupLinearLayer(hidden_size, comm_key_size * num_comm_heads, self.num_units)
        self.value_ = GroupLinearLayer(hidden_size, comm_value_size * num_comm_heads, self.num_units)
        self.comm_attention_output = GroupLinearLayer(num_comm_heads * comm_value_size, comm_value_size, self.num_units)
        self.comm_dropout = nn.Dropout(p=comm_dropout)
        self.input_dropout = nn.Dropout(p=input_dropout)

    def update_global_workspace(self, hs):
        """Recompute the workspace from the units' hidden states.

        Input : hs (batch_size, num_units, hidden_size)
        Effect: sets ``self.global_workspace`` to (batch_size, 1, hidden_size),
                the mean hidden state of units whose gate
                sigmoid(broadcast_gate(h)) exceeds 0.5 (zeros if none do).
        """
        gate_outputs = torch.sigmoid(self.broadcast_gate(hs))  # (batch, num_units, 1)
        # NOTE(review): the hard threshold is not differentiable, so
        # broadcast_gate receives no gradient through this path.
        broadcast_mask = (gate_outputs > 0.5).float()
        weighted_hs = hs * broadcast_mask
        num_broadcasting = broadcast_mask.sum(dim=1, keepdim=True).clamp(min=1)  # (batch, 1, 1)
        self.global_workspace = weighted_hs.sum(dim=1, keepdim=True) / num_broadcasting

    def retrieve_global_workspace_info(self):
        """Return the workspace repeated for every unit: (batch_size, num_units, hidden_size)."""
        return self.global_workspace.expand(-1, self.num_units, -1)

    def transpose_for_scores(self, x, num_attention_heads, attention_head_size):
        """Reshape (batch, seq, heads * size) into (batch, heads, seq, size)."""
        new_x_shape = x.size()[:-1] + (num_attention_heads, attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def input_attention_mask(self, x, h):
        """
        Input : x (batch_size, 2, input_size) [The null input is appended along the first dimension]
                h (batch_size, num_units, hidden_size)
        Output: inputs (batch_size, num_units, input_value_size); rows of inactive units are zero
                mask_ (batch_size, num_units) with 1 for the k active units and 0 otherwise
        """
        key_layer = self.key(x)
        value_layer = self.value(x)
        query_layer = self.query(h)

        key_layer = self.transpose_for_scores(key_layer, self.num_input_heads, self.input_key_size)
        value_layer = torch.mean(self.transpose_for_scores(value_layer, self.num_input_heads, self.input_value_size), dim=1)
        query_layer = self.transpose_for_scores(query_layer, self.num_input_heads, self.input_query_size)

        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) / math.sqrt(self.input_key_size)
        attention_scores = torch.mean(attention_scores, dim=1)  # average over heads: (batch, num_units, 2)

        # Activate the k units that attend most strongly to the real (non-null) input.
        not_null_scores = attention_scores[:, :, 0]
        topk1 = torch.topk(not_null_scores, self.k, dim=1)
        mask_ = torch.zeros(x.size(0), self.num_units, device=x.device)
        mask_.scatter_(1, topk1.indices, 1.0)

        attention_probs = self.input_dropout(nn.Softmax(dim=-1)(attention_scores))
        inputs = torch.matmul(attention_probs, value_layer) * mask_.unsqueeze(2)
        return inputs, mask_

    def communication_attention(self, h, mask):
        """
        Input : h (batch_size, num_units, hidden_size)
                mask (batch_size, num_units) obtained from input_attention_mask()
        Output: context_layer (batch_size, num_units, hidden_size). New hidden states after communication
        """
        query_layer = self.query_(h)
        key_layer = self.key_(h)
        value_layer = self.value_(h)

        query_layer = self.transpose_for_scores(query_layer, self.num_comm_heads, self.comm_query_size)
        key_layer = self.transpose_for_scores(key_layer, self.num_comm_heads, self.comm_key_size)
        value_layer = self.transpose_for_scores(value_layer, self.num_comm_heads, self.comm_value_size)
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.comm_key_size)

        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        # Zero the attention rows of inactive units (broadcast over all heads).
        attention_probs = attention_probs * mask[:, None, :, None]
        attention_probs = self.comm_dropout(attention_probs)
        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.num_comm_heads * self.comm_value_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        context_layer = self.comm_attention_output(context_layer)
        context_layer = context_layer + h
        return context_layer

    def forward(self, x, hs, cs=None):
        """
        Input : x (batch_size, 1, input_size)
                hs (batch_size, num_units, hidden_size)
                cs (batch_size, num_units, hidden_size) or None (zeros are used internally)
        Output: new hs, and new cs (or None when cs was None)
        """
        size = x.size()
        null_input = torch.zeros(size[0], 1, size[2]).float().to(self.device)
        x = torch.cat((x, null_input), dim=1)

        # Compute input attention
        inputs, mask = self.input_attention_mask(x, hs)
        h_old = hs * 1.0
        has_cell_state = cs is not None
        if not has_cell_state:
            # GroupLSTMCell always needs a cell state.
            cs = torch.zeros_like(hs)
        c_old = cs * 1.0

        # Update the global workspace with information from units, then share it with all units
        self.update_global_workspace(hs)
        global_info = self.retrieve_global_workspace_info()
        inputs_with_global_info = torch.cat((inputs, global_info), dim=2)

        # Compute RNN(LSTM) output with integrated global workspace information
        hs, cs = self.rnn(inputs_with_global_info, (hs, cs))

        # Block gradient through inactive units
        mask = mask.unsqueeze(2)
        h_new = blocked_grad.apply(hs, mask)

        # Compute communication attention
        h_new = self.communication_attention(h_new, mask.squeeze(2))

        # Inactive units keep their previous state
        hs = mask * h_new + (1 - mask) * h_old
        if has_cell_state:
            cs = mask * cs + (1 - mask) * c_old
            return hs, cs
        return hs, None


class RIM(nn.Module):
    """Stack of GNWCell layers unrolled over a sequence, optionally bidirectional.

    Args:
        device: 'cuda' or torch.device('cuda') selects the GPU; anything else the CPU.
        input_size: feature size of each time step of the input sequence.
        hidden_size: hidden size of each unit.
        num_units: number of units per GNWCell.
        k: number of units activated per step.
        rnn_cell: forwarded to every GNWCell.
        n_layers: number of stacked layers.
        bidirectional: if True, each layer has a forward and a backward cell
            whose outputs are concatenated along the feature dimension.
        **kwargs: extra GNWCell keyword arguments (attention sizes, dropouts, ...).
    """

    def __init__(self, device, input_size, hidden_size, num_units, k, rnn_cell, n_layers, bidirectional, **kwargs):
        super().__init__()
        # str() so that torch.device('cuda') is recognised as well as the string 'cuda'.
        if str(device) == 'cuda':
            self.device = torch.device('cuda')
        else:
            self.device = torch.device('cpu')
        self.n_layers = n_layers
        self.num_directions = 2 if bidirectional else 1
        self.rnn_cell = rnn_cell
        self.num_units = num_units
        self.hidden_size = hidden_size

        # Cells are ordered [layer0_fw, (layer0_bw), layer1_fw, ...]. Cells of the
        # first layer read the raw input; deeper cells read the previous layer's
        # (direction-concatenated) outputs.
        deeper_input_size = self.num_directions * hidden_size * self.num_units
        self.GNWCell = nn.ModuleList([
            GNWCell(self.device, input_size if i < self.num_directions else deeper_input_size,
                    hidden_size, num_units, k, rnn_cell, **kwargs).to(self.device)
            for i in range(self.n_layers * self.num_directions)
        ])

    def layer(self, rim_layer, x, h, c=None, direction=0):
        """Run one GNWCell over the whole sequence.

        Args:
            rim_layer: the GNWCell to apply.
            x: (seq_len, batch_size, features) input sequence.
            h: initial hidden state (1, batch_size, hidden_size * num_units).
            c: initial cell state with the same shape as ``h``, or None.
            direction: 0 processes the sequence front to back, 1 back to front;
                outputs are returned in the original time order either way.

        Returns:
            outputs (seq_len, batch_size, hidden_size * num_units), final h
            (batch_size, hidden_size * num_units) and, when ``c`` is given,
            final c of the same shape.
        """
        batch_size = x.size(1)
        steps = list(torch.split(x, 1, dim=0))
        if direction == 1:
            steps.reverse()
        hs = h.squeeze(0).view(batch_size, self.num_units, -1)
        cs = c.squeeze(0).view(batch_size, self.num_units, -1) if c is not None else None
        outputs = []
        for step in steps:
            # (1, batch, features) -> (batch, 1, features), as GNWCell expects
            hs, cs = rim_layer(step.squeeze(0).unsqueeze(1), hs, cs)
            outputs.append(hs.view(1, batch_size, -1))
        if direction == 1:
            outputs.reverse()
        outputs = torch.cat(outputs, dim=0)
        if c is not None:
            return outputs, hs.view(batch_size, -1), cs.view(batch_size, -1)
        return outputs, hs.view(batch_size, -1)

    def forward(self, x, h=None, c=None):
        """
        Input: x (seq_len, batch_size, feature_size)
               h (num_layers * num_directions, batch_size, hidden_size * num_units) or None
               c (num_layers * num_directions, batch_size, hidden_size * num_units) or None
               Missing initial states are sampled from a standard normal distribution.
        Output: outputs (seq_len, batch_size, hidden_size * num_units * num_directions)
                h, c (num_layers * num_directions, batch_size, hidden_size * num_units)
        """
        state_shape = (self.n_layers * self.num_directions, x.size(1), self.hidden_size * self.num_units)
        if h is None:
            h = torch.randn(*state_shape).to(self.device)
        if c is None:
            c = torch.randn(*state_shape).to(self.device)
        hs = list(torch.split(h, 1, 0))
        cs = list(torch.split(c, 1, 0))

        for n in range(self.n_layers):
            idx = n * self.num_directions
            x_fw, hs[idx], cs[idx] = self.layer(self.GNWCell[idx], x, hs[idx], cs[idx])
            if self.num_directions == 2:
                idx += 1
                x_bw, hs[idx], cs[idx] = self.layer(self.GNWCell[idx], x, hs[idx], cs[idx], direction=1)
                x = torch.cat((x_fw, x_bw), dim=2)
            else:
                x = x_fw
        return x, torch.stack(hs, dim=0), torch.stack(cs, dim=0)