### load pubmed QA

In [1]:
### pubmed QA
from datasets import load_dataset, concatenate_datasets

# List of all the fold names
folds = [f"pubmed_qa_labeled_fold{i}_source" for i in range(1)]

# Load each fold and store in a list
datasets_list = [load_dataset("bigbio/pubmed_qa", fold, data_dir="/shared/s1/lab06/jiyongan/data/bigbio_pubmed_qa") for fold in folds]

# Concatenate all the loaded datasets
# Extract and concatenate each split separately
merged_train_datasets = concatenate_datasets([dataset["train"] for dataset in datasets_list if "train" in dataset])
merged_validation_datasets = concatenate_datasets([dataset["validation"] for dataset in datasets_list if "validation" in dataset])
merged_test_datasets = concatenate_datasets([dataset["test"] for dataset in datasets_list if "test" in dataset])

# Concatenate all the merged splits together
final_merged_dataset = concatenate_datasets([merged_train_datasets, merged_validation_datasets, merged_test_datasets])

In [2]:
generated_data2 = load_dataset("bigbio/pubmed_qa", "pubmed_qa_artificial_source", data_dir="/shared/s1/lab06/jiyongan/data/bigbio_pubmed_qa")
final_merged_dataset2 = concatenate_datasets([generated_data2['train'], generated_data2['validation']])

In [3]:
#concat human annotated and artifically generated
final_merged_dataset3 = concatenate_datasets([final_merged_dataset, final_merged_dataset2])

### extract Gene-related Question

In [4]:
# Your original set
my_set = {'gene', 'Gene', 'Genetic', 'Genome'}

# Elements to add
additional_elements = [
    'Genome', 'Genome Components', 'Genome Size', 'Genome, Archaeal', 
    'Genome, Bacterial', 'Genome, Chloroplast', 'Genome, Fungal', 
    'Genome, Helminth', 'Genome, Human', 'Genome, Insect', 'Genome, Microbial', 
    'Genome, Mitochondrial', 'Genome, Plant', 'Genome, Plastid', 
    'Genome, Protozoan', 'Genome, Viral', 'Genome-Wide Association Study', 
    'Genomic Imprinting', 'Genomic Instability', 'Genomic Islands', 
    'Genomic Library', 'Genomic Structural Variation', 'Genomics', 
    'Genotype', 'Genotyping Techniques'
]

# Update the set with the new elements
my_set.update(additional_elements)

# Now my_set contains the original elements plus the new ones
print(my_set)


{'Genome, Insect', 'Genetic', 'Genome, Plant', 'Genomic Instability', 'Gene', 'gene', 'Genome, Protozoan', 'Genome, Chloroplast', 'Genome, Human', 'Genome, Plastid', 'Genome, Fungal', 'Genome', 'Genomic Imprinting', 'Genome, Mitochondrial', 'Genomic Structural Variation', 'Genome, Bacterial', 'Genome, Microbial', 'Genome-Wide Association Study', 'Genotyping Techniques', 'Genome, Helminth', 'Genotype', 'Genome Components', 'Genome, Viral', 'Genomics', 'Genome, Archaeal', 'Genomic Islands', 'Genomic Library', 'Genome Size'}


In [5]:
data = []

for data_entry in final_merged_dataset3:
    if any(mesh in my_set for mesh in data_entry['MESHES']):
        data.append(data_entry)
        
print(len(data))

8026


In [6]:
data[0]

{'QUESTION': 'Does Molecular Genotype Provide Useful Information in the Management of Radioiodine Refractory Thyroid Cancers?',
 'CONTEXTS': ['Whether mutation status should be used to guide therapy is an important issue in many cancers. We correlated mutation profile in radioiodine-refractory (RAIR) metastatic thyroid cancers (TCs) with patient outcome and response to tyrosine kinase inhibitors (TKIs), and discussed the results with other published data.',
  'Outcome in 82 consecutive patients with metastatic RAIR thyroid carcinoma prospectively tested for BRAF, RAS and PI3KCA mutations was retrospectively analyzed, including 55 patients treated with multikinase inhibitors.',
  'Papillary thyroid carcinomas (PTCs) were the most frequent histological subtype (54.9 %), followed by poorly differentiated thyroid carcinoma [PDTC] (30.5 %) and follicular thyroid carcinoma [FTC](14.6 %). A genetic mutation was identified in 23 patients (28 %) and BRAF was the most frequently mutated gene (23

### preprocessing dataset

In [7]:
import random

def preprocess_data(data):
    formatted_data = []

    for i, item in enumerate(data):
        query = item['QUESTION']
        positive = item['LONG_ANSWER']
        
        negatives = []
        indices = list(range(len(data)))
        random.shuffle(indices)
        for j in indices:
            if j != i and len(negatives) < 2:  
                negatives.append(data[j]['LONG_ANSWER'])

        formatted_data.append({
            "query": query,
            "positive": positive,
            "negatives": negatives
        })

    return formatted_data

dataset = preprocess_data(data)


In [8]:
dataset[0]

{'query': 'Does Molecular Genotype Provide Useful Information in the Management of Radioiodine Refractory Thyroid Cancers?',
 'positive': 'Patients with BRAF-mutant PTC had a significantly longer PFS than BRAF wild-type when treated with TKIs. However, due to the small number of BRAF-mutant patients, further investigations are required, especially to understand the potential positive effect of BRAF mutations in RAIR TC patients while having a negative prognostic impact in RAI-sensitive PTC patients.',
 'negatives': ['Based on our data, approximately one in 4-5 individuals from the general population may be a carrier of null mutations that are responsible for HRD. This would be the highest mutation carrier frequency so far measured for a class of Mendelian disorders, especially considering that missenses and other forms of pathogenic changes were not included in our assessment. Among other things, our results indicate that the risk for a consanguineous couple of generating a child with 

### create custom dataset

In [10]:
### load dataset
import json
file_path = 'pubmed_gene_data.json'
with open(file_path, 'r') as file:
    dataset = json.load(file)

In [11]:
dataset[0]

{'query': 'Does Molecular Genotype Provide Useful Information in the Management of Radioiodine Refractory Thyroid Cancers?',
 'positive': 'Patients with BRAF-mutant PTC had a significantly longer PFS than BRAF wild-type when treated with TKIs. However, due to the small number of BRAF-mutant patients, further investigations are required, especially to understand the potential positive effect of BRAF mutations in RAIR TC patients while having a negative prognostic impact in RAI-sensitive PTC patients.',
 'negatives': ['This study confirms the high incidence of steatosis in patients infected by hepatitis C virus genotypes non-3, well linked to the development of fibrosis and metabolic abnormalities. Importantly, the present findings put emphasis on the early development of these metabolic abnormalities as they were already found in lean patients with chronic hepatitis C. The direct implication of hepatitis C virus is thus further stressed in the development of steatosis and insulin resist

In [12]:
from torch.utils.data import DataLoader, Dataset
import torch

class CustomDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        return {
            "query": item["query"],
            "positive": item["positive"],
            "negatives": item["negatives"]
        }

from torch.utils.data import DataLoader, random_split

dataset_size = len(dataset)
train_size = int(dataset_size * 0.8)  
val_size = dataset_size - train_size  

train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=16, shuffle=False)  # Usually, we don't need to shuffle the validation data


### processing batch

In [11]:
def process_batch(model, batch):
    # Tokenize the inputs
    inputs_query = tokenizer(batch['query'], padding=True, truncation=True, return_tensors="pt")
    inputs_positive = tokenizer(batch['positive'], padding=True, truncation=True, return_tensors="pt")

    # Tokenize each negative sample individually
    negative_embs = []
    for negative in batch['negatives']:
        inputs_negative = tokenizer(negative, padding=True, truncation=True, return_tensors="pt")
        inputs_negative = {k: v.to(model.device) for k, v in inputs_negative.items()}
        negative_emb = model(**inputs_negative).last_hidden_state[:, 0, :]
        negative_embs.append(negative_emb)

    # Move to the same device as the model
    inputs_query = {k: v.to(model.device) for k, v in inputs_query.items()}
    inputs_positive = {k: v.to(model.device) for k, v in inputs_positive.items()}

    # Get the embeddings from the model
    anchor_emb = model(**inputs_query).last_hidden_state[:, 0, :]
    positive_emb = model(**inputs_positive).last_hidden_state[:, 0, :]

    # Aggregate negative embeddings
    negative_emb = torch.stack(negative_embs).mean(dim=0)

    return anchor_emb, positive_emb, negative_emb

In [None]:
from transformers import AutoModel, AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
from datasets import Dataset, Features, Value, ClassLabel
from transformers import BertModel
import torch
import torch.nn as nn
import torch.optim as optim

# Load the model
tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-large-en-v1.5")
model = AutoModel.from_pretrained("BAAI/bge-large-en-v1.5")

# Define the optimizer with weight decay (L2 regularization)
optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01) 

class TripletLoss(nn.Module):
    def __init__(self, margin=1.0):
        super(TripletLoss, self).__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative):
        distance_positive = (anchor - positive).pow(2).sum(1)
        distance_negative = (anchor - negative).pow(2).sum(1)
        losses = torch.relu(distance_positive - distance_negative + self.margin)
        return losses.mean()

loss_fn = TripletLoss()

# Early Stopping Parameters
best_val_loss = float('inf')
epochs_no_improve = 0
n_epochs_stop = 3  # Number of epochs to wait for improvement before stopping

# Custom Training Loop
for epoch in range(10):
    model.train()
    total_loss = 0
    for batch in train_dataloader:
        # Process your batch to get anchor, positive, and negative embeddings
        anchor_emb, positive_emb, negative_emb = process_batch(model, batch)

        # Compute the loss
        loss = loss_fn(anchor_emb, positive_emb, negative_emb)
        print(loss)
        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_train_loss = total_loss / len(train_dataloader)
    print(f"Epoch {epoch}: Average Training Loss = {avg_train_loss}")

    # Validation step for early stopping
    model.eval()
    total_val_loss = 0
    with torch.no_grad():
        for val_batch in val_dataloader:  # Assuming you have a validation dataloader

            avg_val_loss = total_val_loss / len(val_dataloader)
            print(f"Epoch {epoch}: Average Validation Loss = {avg_val_loss}")

            # Check for early stopping
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                epochs_no_improve = 0
                # Save the best model
                torch.save(model, 'finetuned_bge_large.bin')
            else:
                epochs_no_improve += 1
                if epochs_no_improve == n_epochs_stop:
                    print("Early stopping triggered")
                    break


tensor(11.9468, grad_fn=<MeanBackward0>)
tensor(21.3569, grad_fn=<MeanBackward0>)
