<a href="https://colab.research.google.com/github/PaoloGerosa/Chemotherapy-Associated-Liver-Injury/blob/main/Graph.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import plotly.graph_objects as go
import numpy as np
import plotly.express as px
import pandas as pd 
from plotly.subplots import make_subplots
import plotly.figure_factory as ff
import numpy as np
import Clustering 


In [None]:
import plotly.graph_objects as go
import numpy as np

def plot_3D(*args):
  """Render a scalar field as a semi-transparent Plotly volume and show it.

  The call form is selected by the number of positional arguments:

  * ``plot_3D(values, title)`` -- sample grid of 10 x 10 x 30 points over
    [-1, 1] on every axis; the iso-surface range is the mean of ``values``
    plus/minus two standard deviations.
  * ``plot_3D(values, isomin, isomax, title)`` -- sample grid of
    30 x 10 x 10 points over [-15, 15] x [-5, 5] x [-5, 5]; the iso-surface
    range is the one passed in.

  ``values`` must provide ``flatten()``; presumably it has to hold exactly
  one value per grid point -- confirm against the callers.
  Any other number of arguments draws nothing and returns ``None``.
  """
  n_args = len(args)

  if n_args == 2:
    values, title = args
    grid = np.mgrid[-1:1:10j, -1:1:10j, -1:1:30j]
    centre, spread = np.mean(values), 2*np.std(values)
    volume_options = dict(
        isomin=centre - spread,
        isomax=centre + spread,
        opacity=0.2,       # has to stay small to see through all surfaces
        surface_count=20,  # has to be large for good volume rendering
    )
  elif n_args == 4:
    values, iso_low, iso_high, title = args
    grid = np.mgrid[-15:15:30j, -5:5:10j, -5:5:10j]
    volume_options = dict(
        isomin=iso_low,
        isomax=iso_high,
        opacity=0.1,       # has to stay small to see through all surfaces
        surface_count=17,  # has to be large for good volume rendering
    )
  else:
    # Unsupported call form: nothing is drawn.
    return

  grid_x, grid_y, grid_z = grid
  fig = go.Figure(data=go.Volume(
      x=grid_x.flatten(),
      y=grid_y.flatten(),
      z=grid_z.flatten(),
      value=values.flatten(),
      **volume_options,
  ))
  fig.update_layout(title=title)
  fig.show()

In [None]:
def plot_slices(volume):
  """Show the slices of ``volume`` as an animated stack with a slider.

  Every slice ``volume[i]`` is drawn as a flat surface at height ``i``,
  coloured with the slice flipped upside-down (``np.flipud``). Frame ``k``
  shows slice ``len(volume) - 1 - k``, so the animation starts at the last
  slice and moves down to the first one. The colour range is shared by all
  frames: mean of the whole volume plus/minus two standard deviations.

  Args:
    volume: indexable stack of 2-D arrays (e.g. a 3-D numpy array). The
      number of frames equals ``len(volume)``.
  """
  n_slices = len(volume)
  top = n_slices - 1
  rows, cols = volume[0].shape

  # The colour range depends on the whole volume only, so compute it once
  # instead of once per frame.
  centre = np.mean(volume)
  spread = 2*np.std(volume)
  cmin, cmax = centre - spread, centre + spread

  # Define frames
  frames = [
      go.Frame(
          data=go.Surface(
              z=(top - k) * np.ones((rows, cols)),
              surfacecolor=np.flipud(volume[top - k]),
              cmin=cmin, cmax=cmax,
          ),
          name=str(k),  # frames must be named for the animation to behave properly
      )
      for k in range(n_slices)
  ]
  fig = go.Figure(frames=frames)

  # Add data to be displayed before animation starts
  fig.add_trace(go.Surface(
      z=top * np.ones((rows, cols)),
      surfacecolor=np.flipud(volume[top]),
      cmin=cmin, cmax=cmax, opacity=0.8,
      colorbar=dict(thickness=20, ticklen=4),
  ))

  def frame_args(duration):
    # Animation options shared by the play/pause buttons and the slider.
    return {
        "frame": {"duration": duration},
        "mode": "immediate",
        "fromcurrent": True,
        "transition": {"duration": duration, "easing": "linear"},
    }

  sliders = [
      {
          "pad": {"b": 10, "t": 60},
          "len": 0.9,
          "x": 0.1,
          "y": 0,
          "steps": [
              {
                  "args": [[f.name], frame_args(0)],
                  "label": str(k),
                  "method": "animate",
              }
              for k, f in enumerate(fig.frames)
          ],
      }
  ]

  # Layout
  fig.update_layout(
      title='Slices in a volumetric liver image',
      width=600,
      height=600,
      scene=dict(
          zaxis=dict(range=[-0.1, top], autorange=False),
          aspectratio=dict(x=1, y=1, z=1),
      ),
      updatemenus=[
          {
              "buttons": [
                  {
                      "args": [None, frame_args(150)],
                      "label": "&#9654;",  # play symbol
                      "method": "animate",
                  },
                  {
                      "args": [[None], frame_args(0)],
                      "label": "&#9724;",  # pause symbol
                      "method": "animate",
                  },
              ],
              "direction": "left",
              "pad": {"r": 10, "t": 70},
              "type": "buttons",
              "x": 0.1,
              "y": 0,
          }
      ],
      sliders=sliders,
  )

  fig.show()

In [None]:
import matplotlib.pyplot as plt

def plot_hist(hist):
  """Draw a line plot from a two-item histogram-like sequence and show it.

  ``hist[0]`` supplies the x coordinates and ``hist[1]`` without its last
  element supplies the y coordinates.

  NOTE(review): this looks like the ``(counts, bin_edges)`` pair returned by
  ``numpy.histogram``; if so, the counts end up on the x axis -- confirm
  that this orientation is intended.
  """
  x_values = hist[0]
  y_values = hist[1][:-1]
  plt.plot(x_values, y_values)
  plt.show()

# **Graphics for Clustering**

In [None]:
def visualize(dataset, dist, name_alg, k):
  """Cluster the data, then show the bar, pie and confusion-matrix plots.

  Labels come from ``Clustering.clust_methods(dist, name_alg, k)``. The
  score table built by ``Clustering.clustering_score`` loses its ``'CALI'``
  column before it is handed to ``clust_bar`` and ``clust_pie``; the result
  of ``Clustering.confusion_matrix`` goes to ``confusion_heatmap``.

  Args:
    dataset: data passed to the ``Clustering`` scoring helpers.
    dist: first argument of ``Clustering.clust_methods`` (presumably a
      distance matrix -- confirm in ``Clustering``).
    name_alg: name of the clustering algorithm to run.
    k: number of clusters.
  """
  cluster_labels = Clustering.clust_methods(dist, name_alg, k)
  scores = Clustering.clustering_score(dataset, cluster_labels)
  confusion = Clustering.confusion_matrix(dataset, cluster_labels)

  scores.drop(columns='CALI', inplace=True)

  for draw in (clust_bar, clust_pie):
    draw(scores, k)
  confusion_heatmap(confusion)

In [None]:
def visualize_rel(dataset, dist, name_alg, k):
  """Cluster the data and show bar charts of the relative scores.

  The score table from ``Clustering.clustering_score`` is converted with
  ``Clustering.score_relative``, which also receives the sum of the cluster
  labels. The ``'CALI'`` column of the converted table is dropped before
  plotting with ``clust_bar``.

  Args:
    dataset: data passed to the ``Clustering`` scoring helpers.
    dist: first argument of ``Clustering.clust_methods`` (presumably a
      distance matrix -- confirm in ``Clustering``).
    name_alg: name of the clustering algorithm to run.
    k: number of clusters.
  """
  cluster_labels = Clustering.clust_methods(dist, name_alg, k)
  # For 0/1 labels this sum is the size of cluster 1.
  label_total = np.sum(cluster_labels)

  absolute_scores = Clustering.clustering_score(dataset, cluster_labels)
  relative_scores = Clustering.score_relative(dataset, absolute_scores, label_total)
  relative_scores.drop(columns='CALI', inplace=True)

  clust_bar(relative_scores, k)

In [None]:
def clust_bar(matrix, k):
  """Show two bar charts of the per-cluster lesion scores.

  The first chart has one bar group per cluster (x axis ``GROUP``) with the
  nine lesion columns as series. The second chart uses the transposed
  table: one bar group per lesion (x axis ``CALI``) with the clusters as
  series.

  ``matrix`` is left untouched: the helper ``GROUP`` column only ever lives
  on a copy.

  Args:
    matrix: DataFrame with one row per cluster and the lesion names as
      columns. The second chart selects the columns ``0 .. k-1`` of the
      transposed table, so the index is assumed to be ``0 .. k-1``.
    k: number of clusters; must equal the number of rows of ``matrix``.
  """
  groups = list(range(k))

  # assign() returns a copy, so the caller's frame never carries the helper
  # column -- not even if plotting raises half-way through.
  by_group = matrix.assign(GROUP=groups)
  fig = px.bar(by_group, x = "GROUP", y = ['SOS', 'Fibrosis perisinusoidal', 'Fibrosis centrolobular', 'Peliosis', 'NRH', 'Steatosis', 'Lobular flogosis', 'Balooning', 'Steatohepatitis'])
  fig.show()

  # Transposed view: lesions become rows, clusters become columns. A
  # pre-existing 'GROUP' column is left out of the view, not deleted from
  # the caller's frame.
  by_lesion = (
      matrix.drop(columns='GROUP', errors='ignore')
      .T
      .reset_index()
      .rename(columns={'index': 'CALI'})
  )
  fig = px.bar(by_lesion, x = "CALI", y = groups)
  fig.show()

In [None]:
def clust_pie(matrix, k):
  """Show pie charts of the lesion scores, per lesion and per cluster.

  First figure: a 3 x 3 grid with one pie per lesion column (the first nine
  columns of ``matrix``), sliced by cluster. Second figure: a single row of
  ``k`` pies, one per cluster, sliced by lesion.

  ``matrix`` is left untouched: the helper ``GROUP`` column only ever lives
  on a copy and is never drawn as a lesion itself.

  Args:
    matrix: DataFrame with one row per cluster and the lesion names as
      columns. The second figure selects the columns ``0 .. k-1`` of the
      transposed table, so the index is assumed to be ``0 .. k-1``.
    k: number of clusters; must equal the number of rows of ``matrix``.
  """
  grid_side = 3

  # assign() returns a copy and still validates that k matches the number
  # of rows; the caller's frame is not modified.
  group_labels = matrix.assign(GROUP=range(k))['GROUP']
  scores = matrix.drop(columns='GROUP', errors='ignore')

  fig1 = make_subplots(
      rows=grid_side,
      cols=grid_side,
      specs=[[{'type': 'domain'}] * grid_side for _ in range(grid_side)],
  )
  for position, cali in enumerate(scores.columns[:grid_side * grid_side]):
    # Fill the grid row by row; subplot coordinates are 1-based.
    row, col = divmod(position, grid_side)
    fig1.add_trace(
        go.Pie(labels = group_labels, values = scores[cali], title = cali),
        row + 1, col + 1,
    )
  fig1.show()

  # Transposed view: lesions become rows, clusters become columns.
  by_lesion = scores.T.reset_index().rename(columns={'index': 'CALI'})

  fig1 = make_subplots(rows=1, cols=k, specs=[[{'type': 'domain'}] * k])
  for group in range(k):
    title = "CALI GROUP " + str(group)
    fig1.add_trace(
        go.Pie(labels = by_lesion['CALI'], values = by_lesion[group], title = title),
        1, group + 1,
    )
  fig1.show()

In [None]:
def confusion_heatmap(mat):
  """Show ``mat`` as an annotated heatmap titled "Confusion matrix".

  Both axes carry the tick labels "0" and "1", so ``mat`` is expected to be
  2 x 2. The horizontal axis is captioned "Outcome" and the vertical one
  "Cluster".

  Args:
    mat: confusion matrix; anything ``np.array`` can convert.
  """
  cells = np.array(mat)
  outcome_ticks = ["0", "1"]
  cluster_ticks = ["0", "1"]

  # Heatmap with the tick labels on both axes.
  fig = ff.create_annotated_heatmap(
      cells, x=outcome_ticks, y=cluster_ticks, colorscale='Viridis'
  )
  fig.update_layout(title_text='<i><b>Confusion matrix</b></i>')

  # Axis captions, positioned in paper coordinates outside the plot area.
  captions = (
      dict(x=0.5, y=-0.15, text="Outcome"),
      dict(x=-0.35, y=0.5, text="Cluster", textangle=-90),
  )
  for caption in captions:
    fig.add_annotation(dict(
        font=dict(color="black", size=14),
        showarrow=False,
        xref="paper",
        yref="paper",
        **caption,
    ))

  # Applies to every annotation in the layout, the two captions included,
  # so the size of 14 given above ends up replaced by 40.
  for annotation in fig.layout.annotations:
    annotation.font.size = 40

  # Extra top/left room for the title and the rotated "Cluster" caption.
  fig.update_layout(margin=dict(t=50, l=200))

  # Show the colour scale next to the heatmap.
  fig.data[0].showscale = True

  fig.show()

# **1D_functions**

In [None]:
# Plot of 1D functions.
# Example arguments:
#   y1, y2: arrays of values for the y axis
#   vol1, vol2: the corresponding volumes
#   name: array with the patient indexes, in order


def plot_1D(*args):
  """Plot 1-D series as lines in one figure and nested lists as 3-D volumes.

  The last positional argument is the sequence of names; every argument
  before it is a data item. An item whose second element is a ``list`` is
  converted to an array, transposed and passed to ``plot_3D`` with the
  iso-range 0..127; any other item is added to a shared line figure with
  its position (0, 1, 2, ...) on the x axis. Lines and volumes index the
  name sequence with separate counters, both starting at 0.

  Nothing happens unless the number of arguments is odd.
  """
  if len(args) % 2 != 1:
    return

  names = args[-1]
  fig = go.Figure()
  line_count = 0
  volume_count = 0

  for item in args[:-1]:
    if isinstance(item[1], list):
      # Nested lists: render as a volume in a figure of its own.
      plot_3D(np.array(item).T, 0, 127, names[volume_count])
      volume_count += 1
    else:
      fig.add_trace(go.Scatter(
          x=np.arange(len(item)),
          y=item,
          mode='lines+markers',
          name=names[line_count],
      ))
      line_count += 1

  fig.show()
  
  
  
  