This notebook extracts static embeddings (no fine-tuning) from several different LLMs.

-> You can use this notebook to explore how prompting changes the embeddings produced by an LLM, or to extract embeddings with other models.

In [None]:
import os
import pandas as pd
import numpy as np
from tqdm import tqdm
tqdm.pandas()

In [None]:
# Stop early when the merged CSV is missing; the message names the notebook
# that has to be executed first to produce it.
assert os.path.isfile('data/TGCA_Merged.csv'), (
    'Execute 0. Extraction.ipynb'
)

In [None]:
# Read the merged report table; the first CSV column becomes the row index.
report = pd.read_csv(
    'data/TGCA_Merged.csv',
    index_col=0,
)
# Preview the first rows (5 is the pandas default for head()).
report.head(n=5)

### ChatGPT embedding

In [None]:
# Export the OpenAI API key stored on the first line of KEY.txt.
# The context manager closes the file handle, and strip() removes the trailing
# newline that readline() keeps, so the exported value is the bare key.
with open('KEY.txt', 'r') as key_file:
    os.environ['OPENAI_API_KEY'] = key_file.readline().strip()

In [None]:
from openai import OpenAI
# Shared OpenAI client used by the embedding helper in the next cell.
# No arguments are passed, so credentials come from the client's defaults
# (presumably the OPENAI_API_KEY environment variable set above - confirm).
client = OpenAI()

In [None]:
# Deliberate stop: this assertion always fails, so a "Run all" cannot reach the
# OpenAI embedding requests below by accident (the message hints they cost
# money). Comment it out to run the cell on purpose.
assert False, "Are you sure you wanna run? $$$"
def get_embedding(text, model="text-embedding-ada-002"):
    """Return the OpenAI embedding for a single text.

    Parameters
    ----------
    text
        The string to embed; it is sent as a one-element input list.
    model
        Name of the OpenAI embedding model to query.

    Returns
    -------
    The ``embedding`` attribute of the first item in the response data.
    """
    response = client.embeddings.create(input=[text], model=model)
    first_item = response.data[0]
    return first_item.embedding

# Embed every report text (one get_embedding call per row), stack the vectors
# into a frame that keeps the report index, and save it for the later cells.
report_gpt = pd.DataFrame(
    report["text"].apply(get_embedding).tolist(),
    index=report.index,
)
report_gpt.to_csv('data/gpt_embedding.csv')

### Bio Clinical BERT

In [None]:
def get_embedding(text, tokenizer, model):
    """Embed one text with a tokenizer/encoder pair and return a NumPy array.

    Note: this definition shares its name with the OpenAI helper defined in an
    earlier cell and replaces it once this cell has run.

    Parameters
    ----------
    text
        The string to embed. It is tokenized with ``truncation=True`` and
        ``max_length=512``, so anything beyond 512 tokens is dropped.
    tokenizer
        Callable returning a mapping of PyTorch tensors (``return_tensors="pt"``)
        that is unpacked as keyword arguments into ``model``.
    model
        Callable whose result exposes a ``pooler_output`` tensor.

    Returns
    -------
    ``pooler_output`` averaged over its first axis (presumably the batch axis,
    of size one since a single text is passed - confirm), as a NumPy array.
    """
    # Local import keeps this cell self-contained; torch is already required
    # by the return_tensors="pt" tokenizer output.
    import torch

    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    # Inference only: disabling autograd avoids building a gradient graph for
    # every report, which saves memory and time without changing the values.
    with torch.no_grad():
        pooled = model(**inputs).pooler_output
    return pooled.mean(0).detach().numpy()

In [None]:
from transformers import AutoTokenizer, AutoModel
# Both the tokenizer and the encoder come from the same Hugging Face checkpoint.
clinical_bert_checkpoint = "emilyalsentzer/Bio_ClinicalBERT"
tokenizer = AutoTokenizer.from_pretrained(clinical_bert_checkpoint)
# output_attentions=True asks the model to include attention weights in its outputs.
model = AutoModel.from_pretrained(
    clinical_bert_checkpoint,
    output_attentions=True,
)

In [None]:
# Run the first report through the tokenizer and model to inspect the raw
# output; the tokenizer settings match the ones used in get_embedding.
inputs = tokenizer(
    report["text"].iloc[0],
    return_tensors="pt",
    truncation=True,
    max_length=512,
)
output = model(**inputs)

In [None]:
# Display the report table (notebook cell output).
report

In [None]:
# Deliberate stop so the long embedding pass below is never started by accident.
assert False, "Are you sure you wanna run? ~80 min"
# Embed each report text with the currently loaded tokenizer/model (with a
# tqdm progress bar), keep the report index, and save the vectors to CSV.
report_clinicalBERT = pd.DataFrame(
    report["text"]
    .progress_apply(lambda doc: get_embedding(doc, tokenizer, model))
    .tolist(),
    index=report.index,
)
report_clinicalBERT.to_csv('data/clinicalBERT_embedding.csv')

### BERT

In [None]:
from transformers import BertTokenizer, BertModel
# Replace the previously loaded tokenizer/model pair with the generic
# bert-base-uncased checkpoint; both objects are loaded from the same name.
bert_checkpoint = "bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(bert_checkpoint)
model = BertModel.from_pretrained(bert_checkpoint)

In [None]:
# Deliberate stop so the long embedding pass below is never started by accident.
assert False, "Are you sure you wanna run? ~80 min"
# Same pipeline as for the other encoder cells: embed each report text with
# whichever tokenizer/model are currently bound, then save the vectors to CSV.
report_BERT = pd.DataFrame(
    report["text"]
    .progress_apply(lambda doc: get_embedding(doc, tokenizer, model))
    .tolist(),
    index=report.index,
)
report_BERT.to_csv('data/BERT_embedding.csv')

### Visualisation

In [None]:
from sklearn import manifold
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

In [None]:
# Load the outcome table and the saved GPT embeddings; both use their first
# CSV column as the row index.
outcomes = pd.read_csv('data/TGCA_Merged.csv', index_col=0)
embedding = pd.read_csv('data/gpt_embedding.csv', index_col=0)
# Standardise every embedding column: subtract its mean, divide by its
# standard deviation.
embedding = embedding.sub(embedding.mean()).div(embedding.std())

In [None]:
# Keep only the rows whose "type" value is present ...
outcomes = outcomes.loc[outcomes["type"].dropna().index]
# ... and select the embedding rows with the same index labels, in the same order.
embedding = embedding.loc[outcomes.index]

In [None]:
# Project the standardised embeddings to two dimensions with t-SNE; the fixed
# random_state keeps the layout reproducible between runs.
t_sne = manifold.TSNE(
    random_state=42,
    n_components=2,
)
embed_tsne = t_sne.fit_transform(embedding)

In [None]:
# t-SNE scatter coloured by pathologic tumour stage, with a 4-colour viridis
# map; missing stages are filled with 0 and labelled 'Unknown'.
cmap = plt.get_cmap('viridis', 4)
stage = outcomes["ajcc_pathologic_tumor_stage"].fillna(0)
plt.scatter(
    embed_tsne[:, 0],
    embed_tsne[:, 1],
    c=stage,
    cmap=cmap,
    alpha=0.5,
)
cbar = plt.colorbar(label="Stage")
# Tick positions 0.375, 1.125, 1.875, 2.625: the centres of four equal colour
# bands, assuming the colour range spans 0 to 3 - confirm against the data.
stage_ticks = 0.75 * (np.arange(4) + 0.5)
cbar.ax.set_yticks(stage_ticks, ['Unknown', 1, 2, 3])
plt.xlim(-100, 100)
plt.ylim(-100, 100)

In [None]:
# t-SNE scatter coloured by outcomes.t / 365. (presumably days converted to
# years - confirm the unit of `t`).
plt.scatter(embed_tsne[:, 0], embed_tsne[:, 1], c=outcomes.t / 365., cmap='viridis', alpha=0.5)
# Label corrected: no logarithm is applied to the plotted values, so the
# previous "(in log years)" wording did not describe the colour scale.
plt.colorbar(label="Survival Time (in years)")
plt.xlim(-100, 100)
plt.ylim(-100, 100)

In [None]:
# t-SNE scatter coloured by the "grouping" column, one discrete viridis colour
# per category.
grouping = outcomes["grouping"].astype('category').cat
length = len(grouping.categories)
cmap = plt.get_cmap('viridis', length)
plt.scatter(
    embed_tsne[:, 0],
    embed_tsne[:, 1],
    c=grouping.codes,
    cmap=cmap,
    alpha=0.75,
)
cbar = plt.colorbar(ticks=np.arange(length), label="Cancer Subgroups")
# Move the ticks to the middle of each of the `length` equal colour bands
# (codes run from 0 to length - 1) and label them with the category names.
band_centres = (length - 1) / length * (np.arange(length) + 0.5)
cbar.ax.set_yticks(band_centres, grouping.categories)
plt.xlim(-100, 100)
plt.ylim(-100, 100)

In [None]:
# Restrict the t-SNE coordinates, the embeddings and the outcome table to the
# rows whose grouping is "Gynecological". The mask is computed once, before
# `outcomes` itself is filtered.
gyn_mask = outcomes.grouping == "Gynecological"
embed_tsne = embed_tsne[gyn_mask]
embedding = embedding[gyn_mask]
outcomes = outcomes[gyn_mask]

In [None]:
# t-SNE scatter of the filtered rows, coloured by the "type" column with one
# discrete viridis colour per category.
type_categories = outcomes["type"].astype('category').cat.categories
length = len(type_categories)
cmap = plt.get_cmap('viridis', length)
# Codes are taken from the outcome rows selected by the embedding index.
type_codes = outcomes.loc[embedding.index, "type"].astype('category').cat.codes
plt.scatter(
    embed_tsne[:, 0],
    embed_tsne[:, 1],
    c=type_codes,
    cmap=cmap,
    alpha=0.5,
)
cbar = plt.colorbar(ticks=np.arange(length), label="Cancer Subgroups")
# Centre one tick in each of the `length` equal colour bands and label it with
# the category name.
band_centres = (length - 1) / length * (np.arange(length) + 0.5)
cbar.ax.set_yticks(band_centres, type_categories)
plt.xlim(-100, 100)
plt.ylim(-100, 100)