### Shared Embedding model
- Note: to train this model, you must change the model initialization in the Training Manager to instantiate `LELNetwork`.

In [None]:
class LELNetwork(nn.Module):
  """Shared-base network with a single Label Embedding Layer (LEL).

  All three label sets (Ekman, VAD, SemEval) are predicted by one bias-free
  linear layer ``self.LEL``, so every label's weight row lives in the same
  ``h_size``-dimensional space. The output columns are laid out as
  ``[ekman | vad | sem]`` with widths taken from ``out_dims``.

  Args:
    h_size: width of the shared hidden layer.
    dropout: dropout probability applied after the shared layer.
    in_size: width of the input features (768 by default).
    out_dims: mapping with keys ``'ekman'``, ``'vad'`` and ``'sem'`` giving the
      number of outputs of each head. Defaults to the global ``OUT_DIMS``.
  """

  def __init__(self, h_size=256, dropout=0, in_size=768, out_dims=None):
    super().__init__()
    if out_dims is None:
      out_dims = OUT_DIMS
    n_ekman, n_vad, n_sem = out_dims['ekman'], out_dims['vad'], out_dims['sem']
    # Column boundaries of the three heads inside the LEL output; derived from
    # out_dims so they always agree with the LEL width below.
    self._ekman_end = n_ekman
    self._vad_end = n_ekman + n_vad

    self.shared_base = nn.Linear(in_size, h_size)
    self.LEL = nn.Linear(h_size, n_ekman + n_vad + n_sem, bias=False)

    self.dropout = nn.Dropout(p=dropout, inplace=False)
    self.relu = nn.ReLU()
    self.softmax = nn.LogSoftmax(dim=1)  # NB: produces log-probabilities despite the name
    self.sigmoid = nn.Sigmoid()

  def forward(self, x, task):
    """Run the shared layer and the LEL, then apply one activation per head.

    Args:
      x: 2-D input whose last dim is ``in_size``.
      task: 2-D tensor with at least 3 columns; column 0/1/2 scales the
        Ekman/VAD/SemEval logits. Presumably a 0/1 task mask — confirm.

    Returns:
      Tensor ``[log_softmax(ekman) | relu(vad) | sigmoid(sem)]``. A head whose
      task value is 0 gets constant outputs: uniform log-probabilities,
      zeros and 0.5 respectively.
    """
    h = self.dropout(self.relu(self.shared_base(x)))
    y = self.LEL(h)

    ek_end, vad_end = self._ekman_end, self._vad_end
    ekman_filter = task[:, 0].unsqueeze(-1)
    vad_filter = task[:, 1].unsqueeze(-1)
    sem_filter = task[:, 2].unsqueeze(-1)

    # Fill a fresh tensor instead of writing back into ``y``: the products
    # below may save slices of ``y`` for backward (e.g. when ``task`` requires
    # grad), and mutating ``y`` in place would then break autograd.
    out = y.new_empty(y.shape)
    out[:, :ek_end] = self.softmax(y[:, :ek_end] * ekman_filter)
    out[:, ek_end:vad_end] = self.relu(y[:, ek_end:vad_end] * vad_filter)
    out[:, vad_end:] = self.sigmoid(y[:, vad_end:] * sem_filter)
    return out

### Plot the correlation matrix of the label embeddings

In [None]:
def plot_LEL(net, *, labels=None, drop=None, path='corr_matrix.png', csv_path=None):
  """Plot, save and show the correlation matrix of the Label Embedding Layer rows.

  Each output unit of ``net.LEL`` owns one weight row; the Pearson correlation
  between two rows shows how similarly the two labels are embedded in the
  shared hidden space. The heatmap is annotated with values rounded to one
  decimal, written to ``path`` and displayed.

  Args:
    net: model with a ``LEL`` linear layer (e.g. ``LELNetwork``).
    labels: tick labels, one per row of ``net.LEL.weight``. Defaults to the
      5 Ekman, 3 VAD and 11 SemEval labels used in this notebook.
    drop: optional iterable of label indices to leave out of the plot, e.g.
      ``range(5, 8)`` to hide the VAD dimensions. Matrix rows/columns and
      labels are removed together.
    path: file the figure is saved to (300 dpi).
    csv_path: if given, the (possibly reduced) matrix is also written there as
      CSV with a header row of labels.

  Raises:
    ValueError: if the number of labels does not match the number of LEL rows.
  """
  # Resets NumPy's global RNG (kept for compatibility; nothing below is random).
  np.random.seed(42)
  if labels is None:
    labels = ['anger-disgust (Ek)', 'fear (Ek)', 'happy (Ek)', 'sad (Ek)', 'surprise (Ek)', 'V', 'A', 'D', 'anger (S)', 'anticipation (S)', 'disgust (S)', 'fear (S)', 'joy (S)', 'love (S)', 'optimism (S)', 'pessimism (S)', 'sadness (S)', 'surprise (S)', 'trust (S)']
  labels = list(labels)

  # One weight row per label: shape (n_labels, h_size).
  L = net.LEL.weight.data.cpu().numpy()
  # Correlate the label rows with each other -> (n_labels, n_labels).
  corr_matrix = np.corrcoef(L.T, rowvar=False)
  if len(labels) != corr_matrix.shape[0]:
    raise ValueError(
        f"got {len(labels)} labels for {corr_matrix.shape[0]} LEL rows")

  if drop is not None:
    dropped = set(drop)
    keep = [k for k in range(len(labels)) if k not in dropped]
    corr_matrix = corr_matrix[np.ix_(keep, keep)]
    labels = [labels[k] for k in keep]

  if csv_path is not None:
    # '%.4f', not '%d': correlations lie in [-1, 1] and would truncate to 0.
    np.savetxt(csv_path, corr_matrix, delimiter=',', header=','.join(labels),
               comments='', fmt='%.4f')

  n = corr_matrix.shape[0]
  fig, ax = plt.subplots(figsize=(8, 8))
  im = ax.imshow(corr_matrix, cmap='magma', vmin=-1, vmax=1, alpha=0.5)
  fig.colorbar(im, ax=ax, shrink=0.8)
  ax.set_xticks(np.arange(n))
  ax.set_yticks(np.arange(n))
  ax.set_xticklabels(labels)
  ax.set_yticklabels(labels)
  # Rotate the x tick labels so long names do not overlap.
  plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

  # Annotate every cell (row i, column j) with its rounded value.
  for (i, j), value in np.ndenumerate(corr_matrix):
    ax.text(j, i, np.round(value, 1), ha="center", va="center", color="black", fontsize=8)

  ax.set_title("Label Correlation Matrix")
  fig.tight_layout()
  fig.savefig(path, dpi=300)
  plt.show()