In [None]:
import pandas as pd

df_1 = pd.read_csv('/kaggle/input/nest-dataset-ps-1/Problem Statements and Data Sets/usecase_1_.csv')
df_2 = df = pd.read_excel('/kaggle/input/nest-dataset-ps-1/Problem Statements and Data Sets/usecase_4_.xlsx', sheet_name='ctg-studies')
df_3 = pd.read_csv('/kaggle/input/nest-dataset-ps-1/Problem Statements and Data Sets/usecase_3_.csv')

In [None]:
df_1.tail()

In [None]:
df_2.tail()

In [None]:
df_3.tail()

In [None]:
import pandas as pd

# Assuming df_1, df_2, and df_3 are already loaded

# Concatenate the dataframes vertically
combined_df = pd.concat([df_1, df_2, df_3], ignore_index=True)

# Rename 'NCT Number' to 'nct_id'
combined_df.rename(columns={'NCT Number': 'nct_id'}, inplace=True)

combined_df.tail() # To see the output, run the code.

In [None]:
import pandas as pd

# Read the eligibilities.txt file with the correct separator and chunksize
eligibilities_iter = pd.read_csv('/kaggle/input/nest-dataset-ps-1/Problem Statements and Data Sets/eligibilities.txt', sep='|', chunksize=50000)

# Initialize an empty list to store processed chunks
all_chunks = []

# Process the eligibilities.txt file in chunks
for chunk in eligibilities_iter:
    all_chunks.append(chunk)

# Combine all processed chunks
df_elig = pd.concat(all_chunks, ignore_index=True)

In [None]:
df_elig.tail()

In [None]:
import pandas as pd

# Merge the dataframes on 'nct_id' using an inner join
df_4 = pd.merge(combined_df, df_elig, on='nct_id', how='inner')


In [None]:
df_4.tail()

In [None]:
# columns present in df_1 but not in df_4
cols_in_df1_not_df4 = set(df_1.columns) - set(df_4.columns)

# columns present in df_4 but not in df_1
cols_in_df4_not_df1 = set(df_4.columns) - set(df_1.columns)

print("Columns in df_1 but not in df_4:", cols_in_df1_not_df4)
print("Columns in df_4 but not in df_1:", cols_in_df4_not_df1)

In [None]:
df = df_4[['nct_id', 'Study Title', 'Primary Outcome Measures', 'Secondary Outcome Measures', 'criteria', 'Funder Type']]

In [None]:
df.tail()

In [None]:
# Check for missing values in each column of the DataFrame 'df'
missing_values_count = df.isnull().sum()
missing_values_count

In [None]:
df = df.copy()  # Ensure df is a standalone DataFrame

# analyzing the text length for different columns, these text lengths include spaces as well.
text_columns = ['Study Title', 'Primary Outcome Measures', 'Secondary Outcome Measures', 'criteria']
for col in text_columns:
    df.loc[:, f'{col}_length'] = df[col].apply(lambda x: len(str(x)))

In [None]:
df.tail()

In [None]:
import numpy as np
print(np.mean(df['Study Title_length']))
print(np.mean(df['Primary Outcome Measures_length']))
print(np.mean(df['Secondary Outcome Measures_length']))
print(np.mean(df['criteria_length']))

In [None]:
import re

def count_special_chars(text):
    if pd.isna(text):  # Handle missing values
        return 0
    # Define a regex pattern to match specific special characters
    # Add or remove characters as needed
    special_char_pattern = r'[!@#$%^&*()_+{}\[\]:;"\'<>,.?/\\|`~\-=]'
    return len(re.findall(special_char_pattern, text))

# Columns to check for special characters
columns_to_check = ['Study Title', 'Primary Outcome Measures', 'Secondary Outcome Measures', 'criteria']

# Iterate over the specified columns and count special characters
special_char_counts = {}
for column in columns_to_check:
    special_char_counts[column] = df[column].apply(count_special_chars).sum()

# Print the results
for column, count in special_char_counts.items():
    print(f"Column '{column}' has {count} special characters.")

In [None]:
import re
import pandas as pd
from typing import Dict

# Medical abbreviations dictionary
medical_abbreviations = {
    'htn': 'hypertension',
    'mi': 'myocardial infarction',
    't2dm': 'type 2 diabetes mellitus',
    'aki': 'acute kidney injury',
    'aep': 'alcohol-exposed pregnancy',
    'bmi': 'body mass index',
    'allo': 'allogeneic',
    'aebp': 'alcohol exposure biomarker positive',
    'crp': 'c-reactive protein',
    'ckd': 'chronic kidney disease',
    'cvd': 'cardiovascular disease',
    'dm': 'diabetes mellitus',
    'hba1c': 'glycated hemoglobin',
    'hf': 'heart failure',
    'pae': 'potential alcohol exposure',
    'sae': 'serious adverse event',
    'sbp': 'systolic blood pressure',
    'dbp': 'diastolic blood pressure',
    # Add more abbreviations as needed
}

def standardize_units(text: str) -> str:
    """Standardize numerical values and units."""
    text = re.sub(r'(\d+)\s*years?', r'\1 year', text, flags=re.IGNORECASE)
    text = re.sub(r'(\d+)\s*days?', r'\1 day', text, flags=re.IGNORECASE)
    text = re.sub(r'(\d+)\s*weeks?', r'\1 week', text, flags=re.IGNORECASE)
    text = re.sub(r'(\d+)\s*months?', r'\1 month', text, flags=re.IGNORECASE)
    return text

def expand_abbreviations(text: str, abbrev_dict: Dict[str, str]) -> str:
    """Expand medical abbreviations to their full forms."""
    for abbrev, full_form in abbrev_dict.items():
        text = re.sub(r'\b' + abbrev + r'\b', full_form, text)
    return text

def preprocess_text(text: str) -> str:
    """
    Comprehensive text preprocessing function.
    """
    if pd.isna(text) or str(text).strip() == '':
        return ""
    
    # Convert to string and lowercase
    text = str(text).lower()
    
    # Preserve numerical values with units before general cleaning
    text = standardize_units(text)
    
    # Remove HTML tags if any
    text = re.sub(r'<.*?>', '', text)
    
    # Handle bullet points and lists
    text = re.sub(r'[\n\r]+', ' ', text)  # Replace newlines with space
    text = re.sub(r'[•\-*]+', '', text)   # Remove bullet points
    
    # Expand medical abbreviations
    text = expand_abbreviations(text, medical_abbreviations)
    
    # Remove special characters but preserve numbers and units
    text = re.sub(r'[^a-z0-9\s]', ' ', text)
    
    # Handle multiple spaces
    text = re.sub(r'\s+', ' ', text)
    
    return text.strip()

def process_dataframe(df: pd.DataFrame) -> pd.DataFrame:
    """
    Process the entire dataframe.
    """
    columns_to_process = ['Study Title', 'Primary Outcome Measures', 'Secondary Outcome Measures', 'criteria']
    
    # Add missing value flags
    df['Secondary_Outcome_Missing'] = df['Secondary Outcome Measures'].isna().astype(int)
    df['Primary_Outcome_Missing'] = df['Primary Outcome Measures'].isna().astype(int)
    
    # Preprocess and add cleaned text columns
    for column in columns_to_process:
        print(f"Processing {column}...")
        df[column + '_Cleaned'] = df[column].apply(preprocess_text)
    
    # Create new DataFrame with necessary columns
    new_columns = ['nct_id'] + \
                 [col + '_Cleaned' for col in columns_to_process] + \
                 ['Secondary_Outcome_Missing', 'Primary_Outcome_Missing']
    
    return df[new_columns]

# Main execution
if __name__ == "__main__":
    processed_df = process_dataframe(df)
    processed_df.to_csv('clinical_trials_cleaned.csv', index=False)

### Most of the data is cleaned. now basically gotta work towards the missing values.

In [None]:
import pandas as pd

In [None]:
df = pd.read_csv('/kaggle/input/nest-ps-1-cleaned/clinical_trials_cleaned.csv')

In [None]:
df.columns

In [None]:
missing_values = df.isnull().sum()
print(missing_values)

In [None]:
import pandas as pd
import numpy as np
from typing import Dict, List

def handle_missing_values(df: pd.DataFrame) -> pd.DataFrame:
    """
    Comprehensive missing value handling based on the proposed strategy
    """
    # Create a copy to avoid modifying original
    processed_df = df.copy()
    
    # 1. Secondary Outcome Measures (High missing rate)
    # We already have Secondary_Outcome_Missing flag
    # Replace missing values with standardized text
    processed_df['Secondary Outcome Measures_Cleaned'] = processed_df['Secondary Outcome Measures_Cleaned'].fillna('no secondary outcomes provided')
    
    # 2. Primary Outcome Measures (Moderate missing rate)
    # Create flag for missing primary outcomes
    processed_df['Primary_Outcome_Missing'] = processed_df['Primary Outcome Measures_Cleaned'].isna().astype(int)
    
    # Replace missing values with standardized text
    # We're using a placeholder here - in the next phase, we can implement similar-trial imputation
    processed_df['Primary Outcome Measures_Cleaned'] = processed_df['Primary Outcome Measures_Cleaned'].fillna('primary outcome not specified')
    
    # 3. Criteria (Low missing rate)
    # Create flag for missing criteria
    processed_df['Criteria_Missing'] = processed_df['criteria_Cleaned'].isna().astype(int)
    
    # Replace missing values with standardized text
    processed_df['criteria_Cleaned'] = processed_df['criteria_Cleaned'].fillna('no criteria provided')
    
    # 4. Create a data completeness score
    processed_df['completeness_score'] = 1.0
    # Reduce score for missing values
    processed_df.loc[processed_df['Secondary_Outcome_Missing'] == 1, 'completeness_score'] -= 0.2
    processed_df.loc[processed_df['Primary_Outcome_Missing'] == 1, 'completeness_score'] -= 0.4
    processed_df.loc[processed_df['Criteria_Missing'] == 1, 'completeness_score'] -= 0.4
    
    return processed_df

def generate_missing_data_report(original_df: pd.DataFrame, processed_df: pd.DataFrame) -> None:
    """
    Generate a detailed report of missing data handling
    """
    print("Missing Data Report")
    print("-" * 50)
    
    # Original missing values
    print("\nOriginal Missing Values:")
    for column in original_df.columns:
        missing_count = original_df[column].isna().sum()
        total_count = len(original_df)
        missing_percentage = (missing_count / total_count) * 100
        print(f"{column}: {missing_count} missing ({missing_percentage:.2f}%)")
    
    # Processed missing values
    print("\nProcessed Missing Values:")
    for column in processed_df.columns:
        missing_count = processed_df[column].isna().sum()
        total_count = len(processed_df)
        missing_percentage = (missing_count / total_count) * 100
        print(f"{column}: {missing_count} missing ({missing_percentage:.2f}%)")
    
    # Completeness score statistics
    print("\nCompleteness Score Statistics:")
    print(f"Mean: {processed_df['completeness_score'].mean():.2f}")
    print(f"Median: {processed_df['completeness_score'].median():.2f}")
    print(f"Min: {processed_df['completeness_score'].min():.2f}")
    print(f"Max: {processed_df['completeness_score'].max():.2f}")

# Main execution
if __name__ == "__main__":
    # Process the dataframe
    processed_df = handle_missing_values(df)
    
    # Generate report
    generate_missing_data_report(df, processed_df)
    
    # Save processed dataframe
    processed_df.to_csv('clinical_trials_processed_with_missing_handled.csv', index=False)

In [None]:
processed_df.tail()

### Since the data cleaning is done now. We are moving on to processing and converting our data to embeddings.

In [None]:
!pip install sentence-transformers
!pip install scikit-learn
!pip install transformers
!pip install torch
!pip install pandas
!pip install numpy

In [None]:
from sentence_transformers import SentenceTransformer
from sklearn.feature_extraction.text import TfidfVectorizer
from transformers import AutoTokenizer, AutoModel
import torch
import numpy as np
import pandas as pd
from typing import List, Dict

In [None]:
df = pd.read_csv('/kaggle/input/nest-ps-1-data-cleaning-done/clinical_trials_processed_with_missing_handled.csv')

In [None]:
df.columns

In [9]:
from sentence_transformers import SentenceTransformer
from sklearn.feature_extraction.text import TfidfVectorizer
import torch
import numpy as np
from typing import List, Dict
from tqdm.auto import tqdm
import time
import psutil
import gc
import os

def get_memory_usage():
    process = psutil.Process()
    memory_use = process.memory_info().rss / (1024 ** 2)  # in MB
    return memory_use

class TrialEncoder:
    def __init__(self, device='cuda' if torch.cuda.is_available() else 'cpu'):
        self.device = device
        print(f"Using device: {self.device}")
        print(f"Memory before loading SBERT model: {get_memory_usage():.2f} MB")
        
        # Load model with optimizations
        self.sbert_model = SentenceTransformer('pritamdeka/S-PubMedBert-MS-MARCO')
        
        # Convert to half precision if using CUDA
        if self.device == 'cuda':
            self.sbert_model.half()  # Use FP16 for faster processing
        
        self.sbert_model.to(self.device)
        print(f"Memory after loading SBERT model: {get_memory_usage():.2f} MB")
        
        # Initialize TF-IDF with optimized parameters
        self.tfidf = TfidfVectorizer(
            max_features=10000, 
            stop_words='english',
            dtype=np.float32  # Use float32 instead of float64 for memory efficiency
        )
    
    def combine_text_fields(self, row: pd.Series) -> str:
        return (f"TITLE: {row['Study Title_Cleaned']} "
                f"PRIMARY: {row['Primary Outcome Measures_Cleaned']} "
                f"SECONDARY: {row['Secondary Outcome Measures_Cleaned']} "
                f"CRITERIA: {row['criteria_Cleaned']}")

    @torch.cuda.amp.autocast()  # Enable automatic mixed precision
    def encode_sbert(self, texts: List[str]) -> np.ndarray:
        return self.sbert_model.encode(
            texts,
            show_progress_bar=True,
            device=self.device,
            batch_size=256,  # Increased batch size for better GPU utilization
            normalize_embeddings=True,
            convert_to_numpy=True
        )
    
    def encode_tfidf(self, texts: List[str]) -> np.ndarray:
        return self.tfidf.fit_transform(texts).astype(np.float32).toarray()

def process_in_chunks(df: pd.DataFrame, chunk_size: int = 5000, save_path: str = './embeddings/') -> None:
    os.makedirs(save_path, exist_ok=True)
    
    # Initialize encoder
    encoder = TrialEncoder()
    total_chunks = len(df) // chunk_size + (1 if len(df) % chunk_size != 0 else 0)
    
    # Initialize arrays and progress bar
    all_sbert = []
    all_tfidf = []
    
    # Main progress bar
    with tqdm(total=total_chunks, desc="Processing chunks", position=0) as chunk_pbar:
        try:
            for chunk_idx in range(total_chunks):
                start_idx = chunk_idx * chunk_size
                end_idx = min((chunk_idx + 1) * chunk_size, len(df))
                
                # Update progress description
                chunk_pbar.set_description(
                    f"Chunk {chunk_idx + 1}/{total_chunks} [Rows {start_idx}-{end_idx}]"
                )
                
                # Process chunk
                chunk_df = df.iloc[start_idx:end_idx]
                
                # Combine texts efficiently
                combined_texts = [
                    encoder.combine_text_fields(row) 
                    for _, row in tqdm(chunk_df.iterrows(), 
                                     desc="Combining texts", 
                                     position=1, 
                                     leave=False)
                ]
                
                # Generate embeddings
                sbert_embeddings = encoder.encode_sbert(combined_texts)
                all_sbert.append(sbert_embeddings)
                
                tfidf_vectors = encoder.encode_tfidf(combined_texts)
                all_tfidf.append(tfidf_vectors)
                
                # Clear memory
                del combined_texts, sbert_embeddings, tfidf_vectors
                gc.collect()
                torch.cuda.empty_cache()
                
                chunk_pbar.update(1)
                
                # Save backup every 5 chunks
                if (chunk_idx + 1) % 5 == 0:
                    timestamp = time.strftime("%Y%m%d-%H%M%S")
                    backup_path = os.path.join(save_path, 'backups')
                    os.makedirs(backup_path, exist_ok=True)
                    
                    np.save(os.path.join(backup_path, f'sbert_backup_{timestamp}.npy'), 
                           np.vstack(all_sbert))
                    np.save(os.path.join(backup_path, f'tfidf_backup_{timestamp}.npy'), 
                           np.vstack(all_tfidf))
                    
                    chunk_pbar.write(f"Backup saved at chunk {chunk_idx + 1}")
            
            # Save final combined embeddings
            chunk_pbar.write("\nSaving final embeddings...")
            timestamp = time.strftime("%Y%m%d-%H%M%S")
            
            final_sbert = np.vstack(all_sbert)
            np.save(os.path.join(save_path, f'sbert_embeddings_final_{timestamp}.npy'), 
                   final_sbert)
            chunk_pbar.write(f"Final SBERT shape: {final_sbert.shape}")
            del final_sbert
            
            final_tfidf = np.vstack(all_tfidf)
            np.save(os.path.join(save_path, f'tfidf_vectors_final_{timestamp}.npy'), 
                   final_tfidf)
            chunk_pbar.write(f"Final TF-IDF shape: {final_tfidf.shape}")
            del final_tfidf
            
        except Exception as e:
            chunk_pbar.write(f"Error occurred: {str(e)}")
            raise

if __name__ == "__main__":
    try:
        total_start_time = time.time()
        print(f"Starting processing with initial memory: {get_memory_usage():.2f} MB")
        
        process_in_chunks(df, chunk_size=4000)
        
        total_time = time.time() - total_start_time
        print(f"\nTotal execution time: {total_time/60:.2f} minutes")
        
    except Exception as e:
        print(f"Error during processing: {str(e)}")

ModuleNotFoundError: No module named 'sentence_transformers'

### Now we are simply checking the quality of the embeddings by trying to predict the top_k titls simply using cosine similarity

In [None]:
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from typing import Dict, List, Tuple
import time

class SimilarityRetriever:
    def __init__(self, 
                 sbert_embeddings: np.ndarray, 
                 tfidf_embeddings: np.ndarray,
                 alpha: float = 0.7):
        """
        Initialize with both types of embeddings
        
        Args:
            sbert_embeddings: SBERT embeddings array
            tfidf_embeddings: TF-IDF embeddings array
            alpha: Weight for SBERT similarity (1-alpha for TF-IDF)
        """
        # Verify and convert embeddings to float32
        try:
            self.sbert_embeddings = np.array(sbert_embeddings, dtype=np.float32)
            self.tfidf_embeddings = np.array(tfidf_embeddings, dtype=np.float32)
        except ValueError as e:
            print("Error converting embeddings to float32. Please check your embedding files.")
            raise
            
        self.alpha = alpha
        
        # Normalize embeddings for faster cosine similarity
        try:
            self.sbert_normalized = self._normalize_embeddings(self.sbert_embeddings)
            self.tfidf_normalized = self._normalize_embeddings(self.tfidf_embeddings)
        except Exception as e:
            print(f"Error during normalization: {str(e)}")
            raise
        
        print(f"Initialized with shapes: SBERT {self.sbert_embeddings.shape}, TF-IDF {self.tfidf_embeddings.shape}")
    
    def _normalize_embeddings(self, embeddings: np.ndarray) -> np.ndarray:
        """Normalize embeddings for cosine similarity"""
        norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
        # Avoid division by zero
        norms[norms == 0] = 1e-10
        return embeddings / norms

# Loading and verification function
def load_embeddings(sbert_path: str, tfidf_path: str) -> Tuple[np.ndarray, np.ndarray]:
    """
    Load and verify embeddings from files
    
    Args:
        sbert_path: Path to SBERT embeddings
        tfidf_path: Path to TF-IDF embeddings
    
    Returns:
        Tuple of (SBERT embeddings, TF-IDF embeddings)
    """
    try:
        print(f"Loading SBERT embeddings from {sbert_path}")
        sbert = np.load(sbert_path)
        print(f"SBERT embeddings shape: {sbert.shape}")
        
        print(f"Loading TF-IDF embeddings from {tfidf_path}")
        tfidf = np.load(tfidf_path)
        print(f"TF-IDF embeddings shape: {tfidf.shape}")
        
        # Verify data types and contents
        if not np.issubdtype(sbert.dtype, np.number):
            raise ValueError(f"SBERT embeddings contain non-numeric data: {sbert.dtype}")
        if not np.issubdtype(tfidf.dtype, np.number):
            raise ValueError(f"TF-IDF embeddings contain non-numeric data: {tfidf.dtype}")
            
        # Check for NaN or infinite values
        if np.isnan(sbert).any() or np.isinf(sbert).any():
            raise ValueError("SBERT embeddings contain NaN or infinite values")
        if np.isnan(tfidf).any() or np.isinf(tfidf).any():
            raise ValueError("TF-IDF embeddings contain NaN or infinite values")
            
        return sbert, tfidf
        
    except Exception as e:
        print(f"Error loading embeddings: {str(e)}")
        raise


In [None]:
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from scipy.sparse import csr_matrix
import torch
from typing import Dict, List, Tuple
import time

class SimilarityRetriever:
    def __init__(self, 
                 sbert_embeddings: np.ndarray, 
                 tfidf_embeddings: csr_matrix,
                 alpha: float = 0.7):
        """
        Initialize with both types of embeddings
        
        Args:
            sbert_embeddings: SBERT embeddings array
            tfidf_embeddings: TF-IDF embeddings sparse matrix
            alpha: Weight for SBERT similarity (1-alpha for TF-IDF)
        """
        # Verify and convert embeddings to float32
        self.sbert_embeddings = np.array(sbert_embeddings, dtype=np.float32)
        self.tfidf_embeddings = tfidf_embeddings.astype(np.float32)
        
        self.alpha = alpha
        
        # Normalize SBERT embeddings for cosine similarity
        sbert_norms = np.linalg.norm(self.sbert_embeddings, axis=1, keepdims=True)
        sbert_norms[sbert_norms == 0] = 1e-10
        self.sbert_normalized = torch.from_numpy(self.sbert_embeddings / sbert_norms).cuda()
        
        # TF-IDF is kept sparse and normalized already
        # Assuming TF-IDF is already L2 normalized; if not, normalize it
        # For sparse matrices, assume they are normalized
        
        print(f"Initialized with shapes: SBERT {self.sbert_embeddings.shape}, TF-IDF {self.tfidf_embeddings.shape}")
    
    def find_similar_trials(self, query_idx: int, top_k: int = 10) -> List[Tuple[int, float]]:
        """
        Find top-k most similar trials to the query trial
        
        Args:
            query_idx: Index of the query trial
            top_k: Number of top similar trials to retrieve
        
        Returns:
            List of tuples (index, combined_similarity_score)
        """
        # Compute SBERT similarities
        query_sbert = self.sbert_normalized[query_idx].unsqueeze(0)
        sbert_similarities = torch.mm(query_sbert, self.sbert_normalized.t()).squeeze()
        sbert_similarities = sbert_similarities.cpu().numpy()
        
        # Compute TF-IDF similarities
        query_tfidf = self.tfidf_embeddings[query_idx]
        tfidf_similarities = self.tfidf_embeddings.dot(query_tfidf.T).A.flatten()
        
        # Combine similarities
        combined_similarities = (self.alpha * sbert_similarities +
                                 (1 - self.alpha) * tfidf_similarities)
        
        # Get top-k indices
        top_indices = np.argsort(combined_similarities)[::-1][:top_k]
        top_scores = combined_similarities[top_indices]
        
        return list(zip(top_indices, top_scores))
    
    def analyze_similarity_components(self, query_idx: int, trial_idx: int) -> Dict[str, float]:
        """
        Analyze similarity components for a pair of trials
        
        Args:
            query_idx: Index of the query trial
            trial_idx: Index of the target trial
        
        Returns:
            Dictionary with 'sbert_similarity' and 'tfidf_similarity'
        """
        sbert_sim = torch.dot(self.sbert_normalized[query_idx], self.sbert_normalized[trial_idx]).item()
        tfidf_sim = self.tfidf_embeddings[trial_idx].dot(self.tfidf_embeddings[query_idx].T).toarray()[0, 0]
        
        return {
            'sbert_similarity': sbert_sim,
            'tfidf_similarity': tfidf_sim
        }

# Loading and verification function
def load_embeddings(sbert_path: str, tfidf_path: str) -> Tuple[np.ndarray, csr_matrix]:
    """
    Load and verify embeddings from files
    
    Args:
        sbert_path: Path to SBERT embeddings
        tfidf_path: Path to TF-IDF embeddings
    
    Returns:
        Tuple of (SBERT embeddings, TF-IDF embeddings as sparse matrix)
    """
    try:
        print(f"Loading SBERT embeddings from {sbert_path}")
        sbert = np.load(sbert_path)
        print(f"SBERT embeddings shape: {sbert.shape}")
        
        print(f"Loading TF-IDF embeddings from {tfidf_path}")
        tfidf_sparse = np.load(tfidf_path, allow_pickle=True)
        tfidf = csr_matrix((tfidf_sparse['data'], 
                            tfidf_sparse['indices'], 
                            tfidf_sparse['indptr']),
                           shape=tuple(tfidf_sparse['shape']))
        print(f"TF-IDF embeddings shape: {tfidf.shape}")
        
        # Verify data types and contents
        if not np.issubdtype(sbert.dtype, np.number):
            raise ValueError(f"SBERT embeddings contain non-numeric data: {sbert.dtype}")
        if not np.issubdtype(tfidf.dtype, np.number):
            raise ValueError(f"TF-IDF embeddings contain non-numeric data: {tfidf.dtype}")
            
        # Check for NaN or infinite values
        if np.isnan(sbert).any() or np.isinf(sbert).any():
            raise ValueError("SBERT embeddings contain NaN or infinite values")
        if np.isnan(tfidf.data).any() or np.isinf(tfidf.data).any():
            raise ValueError("TF-IDF embeddings contain NaN or infinite values")
            
        return sbert, tfidf
        
    except Exception as e:
        print(f"Error loading embeddings: {str(e)}")
        raise

# Load and check embeddings
sbert_path = '/kaggle/input/nest-ps-1-sbert-embeddings/sbert_embeddings_final_20241228-170631.npy'
tfidf_path = '/kaggle/input/nest-ps-1-tfidf-embeddings/tfidf_vectors_final_20241228-170631.npz'

sbert, tfidf_sparse = load_embeddings(sbert_path, tfidf_path)

print("SBERT shape:", sbert.shape)
print("SBERT dtype:", sbert.dtype)
print("TF-IDF shape:", tfidf_sparse.shape)
print("TF-IDF dtype:", tfidf_sparse.dtype)

# Now initialize the retriever with the dense arrays
retriever = SimilarityRetriever(sbert, tfidf_sparse, alpha=0.7)

# Test the retriever
query_idx = 0
similar_trials = retriever.find_similar_trials(query_idx, top_k=10)

# Print results
print(f"\nMost similar trials to trial {query_idx}:")
for idx, score in similar_trials:
    analysis = retriever.analyze_similarity_components(query_idx, idx)
    print(f"\nTrial {idx}")
    print(f"Combined similarity: {score:.4f}")
    print(f"SBERT similarity: {analysis['sbert_similarity']:.4f}")
    print(f"TF-IDF similarity: {analysis['tfidf_similarity']:.4f}")

In [None]:
import pandas as pd
df = pd.read_csv('/kaggle/input/nest-ps-1-data-cleaning-done/clinical_trials_processed_with_missing_handled.csv')

In [None]:
df.columns

In [None]:
print(df['Study Title_Cleaned'][0])
print(df['Study Title_Cleaned'][1])
print(df['Study Title_Cleaned'][173809])
print(df['Study Title_Cleaned'][173810])
print(df['Study Title_Cleaned'][200735])
print(df['Study Title_Cleaned'][200734])
print(df['Study Title_Cleaned'][131525])
print(df['Study Title_Cleaned'][124411])
print(df['Study Title_Cleaned'][124412])
print(df['Study Title_Cleaned'][87451])

### Since cosine similarity seems to be working fine. Now we are moving on to implementing multiple metrics so that we can get a better overview at things.

In [None]:
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from scipy.sparse import csr_matrix
import torch
from typing import Dict, List, Tuple
import time

class SimilarityMetrics:
    @staticmethod
    def euclidean_similarity(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Convert Euclidean distance to similarity score"""
        distances = torch.cdist(x.unsqueeze(0), y, p=2).squeeze()
        return 1 / (1 + distances)
    
    @staticmethod
    def manhattan_similarity(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Convert Manhattan distance to similarity score"""
        distances = torch.cdist(x.unsqueeze(0), y, p=1).squeeze()
        return 1 / (1 + distances)
    
    @staticmethod
    def sparse_euclidean_similarity(x: csr_matrix, y: csr_matrix) -> np.ndarray:
        """Compute Euclidean similarity for sparse matrices"""
        # Convert query vector to dense array for operations
        x_array = x.toarray().ravel()
        
        # Compute squared norms
        x_squared = np.sum(x_array ** 2)
        y_squared = np.asarray(y.multiply(y).sum(axis=1)).ravel()
        
        # Compute dot product
        dot_product = y.dot(x.T).toarray().ravel()
        
        # Compute distances using broadcasting
        distances = np.sqrt(np.maximum(x_squared + y_squared - 2 * dot_product, 0))
        return 1 / (1 + distances)


class SimilarityRetriever:
    def __init__(self, 
                 sbert_embeddings: np.ndarray, 
                 tfidf_embeddings: csr_matrix,
                 similarity_weights: Dict[str, float] = None):
        """
        Initialize with embeddings and similarity weights
        
        Args:
            sbert_embeddings: SBERT embeddings array
            tfidf_embeddings: TF-IDF embeddings sparse matrix
            similarity_weights: Dictionary of weights for different similarity measures
                                Default: {'cosine_sbert': 0.4, 'euclidean_sbert': 0.3, 'manhattan_sbert': 0.3,
                                          'cosine_tfidf': 0.4, 'euclidean_tfidf': 0.6}
        """
        # Set default weights if none provided
        self.similarity_weights = similarity_weights or {
            'cosine_sbert': 0.4,
            'euclidean_sbert': 0.3,
            'manhattan_sbert': 0.3,
            'cosine_tfidf': 0.4,
            'euclidean_tfidf': 0.6
        }
        
        # Normalize weights if they do not sum to 1.0
        if abs(sum(self.similarity_weights.values()) - 1.0) > 1e-6:
            total = sum(self.similarity_weights.values())
            if total == 0:
                raise ValueError("Similarity weights cannot sum to zero.")
            self.similarity_weights = {k: v / total for k, v in self.similarity_weights.items()}
            print("Weights do not sum to 1.0; normalizing them.")
        
        # Initialize embeddings
        self.sbert_embeddings = np.array(sbert_embeddings, dtype=np.float32)
        self.tfidf_embeddings = tfidf_embeddings.astype(np.float32)
        
        # Normalize SBERT embeddings and move to GPU
        sbert_norms = np.linalg.norm(self.sbert_embeddings, axis=1, keepdims=True)
        sbert_norms[sbert_norms == 0] = 1e-10
        self.sbert_normalized = torch.from_numpy(self.sbert_embeddings / sbert_norms).cuda()
        
        print(f"Initialized with shapes: SBERT {self.sbert_embeddings.shape}, TF-IDF {self.tfidf_embeddings.shape}")
    
    def compute_similarities(self, query_idx: int) -> Dict[str, np.ndarray]:
        """Compute all similarity measures for a query"""
        # Get query vectors
        query_sbert = self.sbert_normalized[query_idx].unsqueeze(0)
        query_tfidf = self.tfidf_embeddings[query_idx]
        
        similarities = {}
        
        # SBERT similarities (GPU)
        similarities['cosine_sbert'] = torch.mm(query_sbert, self.sbert_normalized.t()).squeeze().cpu().numpy()
        similarities['euclidean_sbert'] = SimilarityMetrics.euclidean_similarity(
            query_sbert, self.sbert_normalized).cpu().numpy()
        similarities['manhattan_sbert'] = SimilarityMetrics.manhattan_similarity(
            query_sbert, self.sbert_normalized).cpu().numpy()
        
        # TF-IDF similarities (CPU, sparse)
        similarities['cosine_tfidf'] = self.tfidf_embeddings.dot(query_tfidf.T).toarray().ravel()
        similarities['euclidean_tfidf'] = SimilarityMetrics.sparse_euclidean_similarity(
            query_tfidf, self.tfidf_embeddings)
        
        return similarities


    
    def find_similar_trials(self, query_idx: int, top_k: int = 10) -> List[Tuple[int, float, Dict[str, float]]]:
        """
        Find top-k most similar trials using ensemble of similarity measures
        
        Returns:
            List of tuples (index, combined_score, individual_scores)
        """
        # Compute all similarities
        similarities = self.compute_similarities(query_idx)
        
        # Combine similarities with weights
        combined_similarities = (
            self.similarity_weights['cosine_sbert'] * similarities['cosine_sbert'] +
            self.similarity_weights['euclidean_sbert'] * similarities['euclidean_sbert'] +
            self.similarity_weights['manhattan_sbert'] * similarities['manhattan_sbert'] +
            self.similarity_weights['cosine_tfidf'] * similarities['cosine_tfidf'] +
            self.similarity_weights['euclidean_tfidf'] * similarities['euclidean_tfidf']
        )
        
        # Get top-k results
        top_indices = np.argsort(combined_similarities)[::-1][:top_k]
        
        results = []
        for idx in top_indices:
            individual_scores = {name: sim[idx] for name, sim in similarities.items()}
            results.append((idx, combined_similarities[idx], individual_scores))
        
        return results

# Example usage:
if __name__ == "__main__":
    # Custom weights example
    weights = {
        'cosine_sbert': 0.4,
        'euclidean_sbert': 0.3,
        'manhattan_sbert': 0.3,
        'cosine_tfidf': 0.4,
        'euclidean_tfidf': 0.6
    }
    
    # Load your embeddings here
    # sbert = np.load('path_to_sbert_embeddings.npy')
    # tfidf_sparse = csr_matrix(np.load('path_to_tfidf_embeddings.npy'))
    
    retriever = SimilarityRetriever(sbert, tfidf_sparse, similarity_weights=weights)
    
    query_idx = 0
    similar_trials = retriever.find_similar_trials(query_idx, top_k=10)
    
    print(f"\nMost similar trials to trial {query_idx}:")
    for idx, combined_score, individual_scores in similar_trials:
        print(f"\nTrial {idx}")
        print(f"Combined similarity: {combined_score:.4f}")
        for metric, score in individual_scores.items():
            print(f"{metric}: {score:.4f}")

### Now for the model development phase we are going to follow the following guidelines:
Phase 1: Data Preparation and Label Generation

    Create Training Pairs

    Define Similarity Labels

    Option A: Use existing embeddings to create silver-standard labels
    Option B: Use domain-specific rules (same disease area, intervention)
    Option C: Get expert annotations for a subset
Phase 2: Model Architecture Design

    Siamese Network Base
    Feature Integration
Phase 3: Training Pipeline Setup

    Data Loading
    Training Loop
Phase 4: Model Evaluation

    Metrics Implementation


#### Starting with phase 1

In [1]:
# Installing some dependecies for better performance
! pip install psutil tqdm joblib torch pandas numpy scipy



In [2]:
from tqdm.notebook import tqdm
from sklearn.metrics.pairwise import cosine_similarity
from joblib import Parallel, delayed
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import torch.cuda.amp as amp
import gc
import time
import psutil
from contextlib import contextmanager
from typing import List, Tuple, Dict
from scipy.sparse import csr_matrix
import traceback

@contextmanager
def timer(name: str):
    """Simple timing context manager"""
    start = time.time()
    yield
    print(f"{name} took {time.time() - start:.2f} seconds")
    
def process_tfidf_chunk(start_idx: int, end_idx: int, batch_start: int, batch_end: int, tfidf_matrix: csr_matrix) -> np.ndarray:
    """
    Process a chunk of TF-IDF matrix to compute similarities.
    
    Args:
        start_idx: Start index of the chunk
        end_idx: End index of the chunk
        batch_start: Start index of the batch
        batch_end: End index of the batch
        tfidf_matrix: The TF-IDF sparse matrix
        
    Returns:
        np.ndarray: Cosine similarities for the chunk
    """
    return cosine_similarity(
        tfidf_matrix[batch_start:batch_end],
        tfidf_matrix[start_idx:end_idx]
    )

def compute_tfidf_pair(tfidf_matrix: csr_matrix, x: int, y: int) -> float:
    """
    Standalone function to avoid lambda references for joblib.
    Computes the TF-IDF cosine similarity between two rows in the TF-IDF matrix.
    """
    return cosine_similarity(
        tfidf_matrix[x].reshape(1, -1),
        tfidf_matrix[y].reshape(1, -1)
    )[0][0]

class DataPairGenerator:
    def __init__(self, 
                 df: pd.DataFrame,
                 sbert_embeddings: np.ndarray,
                 tfidf_embeddings: csr_matrix,
                 sbert_threshold: float = 0.8,
                 tfidf_threshold: float = 0.5,
                 pos_neg_ratio: float = 1.0,
                 batch_size: int = 512):
        """
        Initialize the pair generator with simplified GPU references
        to avoid unpicklable objects when using joblib.
        """
        print("Initializing DataPairGenerator...")
        self.df = df
        print(f"Loaded DataFrame with {len(df)} rows.")
        self.tfidf_embeddings = tfidf_embeddings
        print(f"Loaded Tf-IDF embeddings with shape {tfidf_embeddings.shape}.")
        self.sbert_threshold = sbert_threshold
        self.tfidf_threshold = tfidf_threshold
        self.pos_neg_ratio = pos_neg_ratio
        self.batch_size = batch_size
        
        # Convert SBERT embeddings to torch tensor (float16 for memory efficiency)
        print("Converting SBERT embeddings to PyTorch tensor...")
        self.sbert_embeddings = torch.from_numpy(np.array(sbert_embeddings, dtype=np.float16))
        
        # Create index mapping for NCT IDs
        self.nct_to_idx = {nct: idx for idx, nct in enumerate(df['nct_id'])}
        
        # Check GPU availability
        if torch.cuda.is_available():
            print(f"Using GPU: {torch.cuda.get_device_name(0)}")
            self.sbert_embeddings = self.sbert_embeddings.cuda()
        else:
            print("No GPU available. Using CPU only.")
        
        print("Initialization complete.")

    def compute_batch_similarities(self, batch_start: int, batch_end: int) -> Tuple[np.ndarray, np.ndarray]:
        """
        Compute SBERT and TF-IDF similarities for a batch.
        """
        chunk_size = 2000
        if torch.cuda.is_available():
            with torch.cuda.device(0):
                batch_sbert = self.sbert_embeddings[batch_start:batch_end].cuda().float()
                sbert_sims_list = []
                
                # Add progress bar for SBERT similarity computation
                chunks = range(0, len(self.sbert_embeddings), chunk_size)
                with tqdm(total=len(chunks), desc="Computing SBERT similarities", leave=False) as pbar:
                    with torch.autocast("cuda", dtype=torch.float16):
                        for i in chunks:
                            chunk_end = min(i + chunk_size, len(self.sbert_embeddings))
                            chunk_sbert = self.sbert_embeddings[i:chunk_end].cuda().float()
                            chunk_sims = F.cosine_similarity(
                                batch_sbert.unsqueeze(1),
                                chunk_sbert.unsqueeze(0),
                                dim=2
                            )
                            sbert_sims_list.append(chunk_sims.cpu().numpy())
                            del chunk_sbert
                            pbar.update(1)
                            
                sbert_sims = np.concatenate(sbert_sims_list, axis=1)
                del batch_sbert, sbert_sims_list
                torch.cuda.empty_cache()
        else:
            batch_sbert = self.sbert_embeddings[batch_start:batch_end].float()
            sbert_sims_list = []
            chunks = range(0, len(self.sbert_embeddings), chunk_size)
            with tqdm(total=len(chunks), desc="Computing SBERT similarities", leave=False) as pbar:
                for i in chunks:
                    chunk_end = min(i + chunk_size, len(self.sbert_embeddings))
                    chunk_sbert = self.sbert_embeddings[i:chunk_end].float()
                    chunk_sims = F.cosine_similarity(
                        batch_sbert.unsqueeze(1),
                        chunk_sbert.unsqueeze(0),
                        dim=2
                    )
                    sbert_sims_list.append(chunk_sims.numpy())
                    pbar.update(1)
            sbert_sims = np.concatenate(sbert_sims_list, axis=1)
        
        # Parallelize TF-IDF computation
        n_jobs = min(psutil.cpu_count() - 1, 8)
        chunk_size = min(5000, self.tfidf_embeddings.shape[0] // n_jobs)
        chunks = range(0, self.tfidf_embeddings.shape[0], chunk_size)
        
        tfidf_chunks = Parallel(n_jobs=n_jobs, prefer='threads')(
            delayed(process_tfidf_chunk)(i, min(i + chunk_size, self.tfidf_embeddings.shape[0]), batch_start, batch_end, self.tfidf_embeddings)
            for i in tqdm(chunks, desc="TF-IDF chunks", leave=False)
        )
        tfidf_sims = np.hstack(tfidf_chunks)
        
        return sbert_sims, tfidf_sims

    def generate_positive_pairs(self, sample_size: int = None) -> List[Tuple[str, str]]:
        positive_pairs = []
        num_trials = len(self.df)
        
        print("\nGenerating positive pairs...")
        with tqdm(total=num_trials, desc="Processing batches") as pbar:
            for i in range(0, num_trials, self.batch_size):
                batch_end = min(i + self.batch_size, num_trials)
                sbert_sims, tfidf_sims = self.compute_batch_similarities(i, batch_end)
                
                for idx in range(batch_end - i):
                    local_pairs = []
                    global_idx = i + idx
                    nct_id = self.df.iloc[global_idx]['nct_id']

                    similar_mask = (sbert_sims[idx] > self.sbert_threshold) & \
                                   (tfidf_sims[idx] > self.tfidf_threshold)
                    similar_indices = np.where(similar_mask)[0]
                    similar_indices = similar_indices[similar_indices != global_idx]

                    similar_ncts = self.df.iloc[similar_indices]['nct_id'].values
                    local_pairs.extend(list(zip([nct_id] * len(similar_ncts), similar_ncts)))
                    positive_pairs.extend(local_pairs)

                pbar.update(batch_end - i)
                del sbert_sims, tfidf_sims
                gc.collect()

        if sample_size:
            positive_pairs = np.random.choice(
                positive_pairs,
                size=min(sample_size, len(positive_pairs)),
                replace=False
            ).tolist()
        
        print(f"Generated {len(positive_pairs)} positive pairs")
        return positive_pairs

    def generate_negative_pairs(self, num_pairs: int) -> List[Tuple[str, str]]:
        """
        Generate negative pairs with handling for small datasets.
        """
        negative_pairs = []
        num_trials = len(self.df)
        
        # Calculate maximum possible unique pairs
        max_possible_pairs = (num_trials * (num_trials - 1)) // 2
        
        # Adjust batch size based on dataset size
        batch_size = min(1000, max(10, num_trials // 2))
        
        # Adjust number of pairs if necessary
        num_pairs = min(num_pairs, max_possible_pairs)
        print(f"Aiming to generate {num_pairs} negative pairs (maximum possible: {max_possible_pairs})")
        
        print("\nGenerating negative pairs...")
        with tqdm(total=num_pairs, desc="Finding negative pairs") as pbar:
            while len(negative_pairs) < num_pairs:
                # Calculate how many more pairs we need
                pairs_needed = num_pairs - len(negative_pairs)
                
                # Adjust batch size if we're near the end
                current_batch_size = min(batch_size, pairs_needed * 2)
                
                # Generate random indices with replacement for small datasets
                idx1 = np.random.randint(0, num_trials, size=current_batch_size)
                idx2 = np.random.randint(0, num_trials, size=current_batch_size)
                
                # Remove self-pairs
                valid_mask = idx1 != idx2
                idx1 = idx1[valid_mask]
                idx2 = idx2[valid_mask]
                
                if len(idx1) == 0:
                    continue
                
                # Reshape for similarity computation
                indices = np.column_stack((idx1, idx2))
                
                if torch.cuda.is_available():
                    with torch.cuda.device(0):
                        sbert_sims = F.cosine_similarity(
                            self.sbert_embeddings[indices[:, 0]].cuda().float(),
                            self.sbert_embeddings[indices[:, 1]].cuda().float()
                        ).cpu().numpy()
                else:
                    sbert_sims = F.cosine_similarity(
                        self.sbert_embeddings[indices[:, 0]].float(),
                        self.sbert_embeddings[indices[:, 1]].float()
                    ).numpy()
    
                tfidf_sims = Parallel(n_jobs=min(psutil.cpu_count() - 1, 8), prefer='threads')(
                    delayed(compute_tfidf_pair)(self.tfidf_embeddings, idx1, idx2)
                    for idx1, idx2 in indices
                )
    
                # Find valid negative pairs
                valid_mask = (sbert_sims < self.sbert_threshold / 2) & \
                            (np.array(tfidf_sims) < self.tfidf_threshold / 2)
                
                valid_indices = indices[valid_mask]
                new_pairs = [
                    (self.df.iloc[idx1]['nct_id'], self.df.iloc[idx2]['nct_id'])
                    for idx1, idx2 in valid_indices
                ]
                
                # Convert to set for uniqueness check
                existing_pairs = set((p1, p2) for p1, p2 in negative_pairs)
                new_unique_pairs = [
                    pair for pair in new_pairs 
                    if pair not in existing_pairs and (pair[1], pair[0]) not in existing_pairs
                ]
                
                # Add new unique pairs
                negative_pairs.extend(new_unique_pairs[:num_pairs - len(negative_pairs)])
                pbar.update(min(len(new_unique_pairs), num_pairs - pbar.n))
                
                # Clear memory
                del sbert_sims, tfidf_sims
                gc.collect()
                
                # Break if we can't find any more valid pairs
                if len(new_unique_pairs) == 0 and len(negative_pairs) < num_pairs:
                    print(f"\nWarning: Could only generate {len(negative_pairs)} negative pairs")
                    break
    
        return negative_pairs[:num_pairs]

    def create_training_pairs(self, 
                              total_pairs: int = 100000,
                              validation_split: float = 0.2) -> Dict[str, pd.DataFrame]:
        print("\nCreating training pairs...")
        num_pos = min(int(total_pairs / (1 + self.pos_neg_ratio)), len(self.df) * (len(self.df) - 1) // 4)
        
        print(f"Generating positive pairs...")
        pos_pairs = self.generate_positive_pairs()
        pos_pairs = pos_pairs[:num_pos]
        actual_pos = len(pos_pairs)
        
        num_neg = int(actual_pos * self.pos_neg_ratio)
        print(f"Generating {num_neg} negative pairs...")
        neg_pairs = self.generate_negative_pairs(num_neg)
        
        print("\nCreating final datasets...")
        all_pairs = [(p1, p2, 1) for p1, p2 in pos_pairs] + \
                    [(p1, p2, 0) for p1, p2 in neg_pairs]
        np.random.shuffle(all_pairs)
        
        pair_df = pd.DataFrame(all_pairs, columns=['nct_id_1', 'nct_id_2', 'label'])
        
        print("Merging with original dataframe...")
        for suffix in ['1', '2']:
            pair_df = pair_df.merge(
                self.df.add_suffix(f'_{suffix}'),
                left_on=f'nct_id_{suffix}',
                right_on=f'nct_id_{suffix}',
                how='left'
            )
        
        # Split into train and validation
        mask = np.random.rand(len(pair_df)) > validation_split
        train_df = pair_df[mask]
        val_df = pair_df[~mask]
        
        print(f"Created {len(train_df)} training pairs and {len(val_df)} validation pairs")
        return {
            'train': train_df,
            'validation': val_df
        }

def load_embeddings(sbert_path: str, tfidf_path: str) -> Tuple[np.ndarray, csr_matrix]:
    """Load and verify embeddings from files."""
    try:
        print(f"Loading SBERT embeddings from {sbert_path}")
        sbert = np.load(sbert_path)
        print(f"SBERT embeddings shape: {sbert.shape}")
        
        print(f"Loading TF-IDF embeddings from {tfidf_path}")
        tfidf_sparse = np.load(tfidf_path, allow_pickle=True)
        tfidf = csr_matrix(
            (tfidf_sparse['data'], tfidf_sparse['indices'], tfidf_sparse['indptr']),
            shape=tuple(tfidf_sparse['shape'])
        )
        print(f"TF-IDF embeddings shape: {tfidf.shape}")
        
        if not np.issubdtype(sbert.dtype, np.number):
            raise ValueError(f"SBERT embeddings contain non-numeric data: {sbert.dtype}")
        if not np.issubdtype(tfidf.dtype, np.number):
            raise ValueError(f"TF-IDF embeddings contain non-numeric data: {tfidf.dtype}")
        
        if np.isnan(sbert).any() or np.isinf(sbert).any():
            raise ValueError("SBERT embeddings contain NaN or infinite values")
        if np.isnan(tfidf.data).any() or np.isinf(tfidf.data).any():
            raise ValueError("TF-IDF embeddings contain NaN or infinite values")
        
        return sbert, tfidf
    except Exception as e:
        print(f"Error loading embeddings: {str(e)}")
        raise

In [3]:
sbert_path = '/kaggle/input/nest-ps-1-sbert-embeddings/sbert_embeddings_final_20241228-170631.npy'
tfidf_path = '/kaggle/input/nest-ps-1-tfidf-embeddings/tfidf_vectors_final_20241228-170631.npz'

sbert, tfidf_sparse = load_embeddings(sbert_path, tfidf_path)

df = pd.read_csv('/kaggle/input/nest-ps-1-data-cleaning-done/clinical_trials_processed_with_missing_handled.csv')
print(f"DataFrame shape: {df.shape}")

Loading SBERT embeddings from /kaggle/input/nest-ps-1-sbert-embeddings/sbert_embeddings_final_20241228-170631.npy
SBERT embeddings shape: (393934, 768)
Loading TF-IDF embeddings from /kaggle/input/nest-ps-1-tfidf-embeddings/tfidf_vectors_final_20241228-170631.npz
TF-IDF embeddings shape: (393934, 5000)
DataFrame shape: (393934, 9)


In [None]:
# Decrease thresholds or batch sizes for memory constraints
SBERT_THRESHOLD = 0.75
TFIDF_THRESHOLD = 0.5
POS_NEG_RATIO = 1.0
TOTAL_PAIRS = 100000
VALIDATION_SPLIT = 0.2
BATCH_SIZE = 512  # More conservative batch size

torch.manual_seed(42)
np.random.seed(42)

pair_generator = DataPairGenerator(
    df=df,
    sbert_embeddings=sbert,
    tfidf_embeddings=tfidf_sparse,
    sbert_threshold=SBERT_THRESHOLD,
    tfidf_threshold=TFIDF_THRESHOLD,
    pos_neg_ratio=POS_NEG_RATIO,
    batch_size=BATCH_SIZE
)

print(f"\nGenerating training pairs (total pairs: {TOTAL_PAIRS})...")
try:
    with timer("Total dataset creation"):  # Changed from test_generator.timer
        pair_datasets = pair_generator.create_training_pairs(
            total_pairs=TOTAL_PAIRS,
            validation_split=VALIDATION_SPLIT
        )
        
    timestamp = pd.Timestamp.now().strftime("%Y%m%d-%H%M%S")
    print("\nSaving datasets...")
    train_path = f'training_pairs_{timestamp}.csv'
    pair_datasets['train'].to_csv(train_path, index=False)
    print(f"Training dataset saved to: {train_path}")
    
    val_path = f'validation_pairs_{timestamp}.csv'
    pair_datasets['validation'].to_csv(val_path, index=False)
    print(f"Validation dataset saved to: {val_path}")

    print("\nDataset Statistics:")
    print(f"Training pairs: {len(pair_datasets['train']):,}")
    print(f"Validation pairs: {len(pair_datasets['validation']):,}")
    
    train_pos = (pair_datasets['train']['label'] == 1).sum()
    train_neg = (pair_datasets['train']['label'] == 0).sum()
    val_pos = (pair_datasets['validation']['label'] == 1).sum()
    val_neg = (pair_datasets['validation']['label'] == 0).sum()
    
    print("\nClass Distribution:")
    print(f"Training - Positive: {train_pos:,} ({train_pos/len(pair_datasets['train'])*100:.1f}%)")
    print(f"Training - Negative: {train_neg:,} ({train_neg/len(pair_datasets['train'])*100:.1f}%)")
    print(f"Validation - Positive: {val_pos:,} ({val_pos/len(pair_datasets['validation'])*100:.1f}%)")
    print(f"Validation - Negative: {val_neg:,} ({val_neg/len(pair_datasets['validation'])*100:.1f}%)")

except Exception as e:
    print(f"\nError during pair generation: {str(e)}")
    raise

Initializing DataPairGenerator...
Loaded DataFrame with 393934 rows.
Loaded Tf-IDF embeddings with shape (393934, 5000).
Converting SBERT embeddings to PyTorch tensor...
Using GPU: Tesla P100-PCIE-16GB
Initialization complete.

Generating training pairs (total pairs: 100000)...

Creating training pairs...
Generating positive pairs...

Generating positive pairs...


Processing batches:   0%|          | 0/393934 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/79 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/197 [00:00<?, ?it/s]

In [11]:
# Example: select the first 1000 rows from the DataFrame
df_subset = df[:10000]

# For SBERT embeddings, slice to match the first 1000 rows
sbert_subset = sbert[:10000]

# For TF-IDF (csr_matrix), slice the rows to match the first 1000 as well
tfidf_subset = tfidf_sparse[:10000, :]

In [19]:
# Decrease thresholds or batch sizes for memory constraints
SBERT_THRESHOLD = 0.75
TFIDF_THRESHOLD = 0.5
POS_NEG_RATIO = 1.0
TOTAL_PAIRS = 100
VALIDATION_SPLIT = 0.2
BATCH_SIZE = 256  # More conservative batch size

torch.manual_seed(42)
np.random.seed(42)



pair_generator = DataPairGenerator(
    df=df_subset,
    sbert_embeddings=sbert_subset,
    tfidf_embeddings=tfidf_subset,
    sbert_threshold=SBERT_THRESHOLD,
    tfidf_threshold=TFIDF_THRESHOLD,
    pos_neg_ratio=POS_NEG_RATIO,
    batch_size=BATCH_SIZE
)

print(f"\nGenerating training pairs (total pairs: {TOTAL_PAIRS})...")
try:
    with timer("Total dataset creation"):  # Changed from test_generator.timer
        pair_datasets = pair_generator.create_training_pairs(
            total_pairs=TOTAL_PAIRS,
            validation_split=VALIDATION_SPLIT
        )
        
    timestamp = pd.Timestamp.now().strftime("%Y%m%d-%H%M%S")
    print("\nSaving datasets...")
    train_path = f'training_pairs_{timestamp}.csv'
    pair_datasets['train'].to_csv(train_path, index=False)
    print(f"Training dataset saved to: {train_path}")
    
    val_path = f'validation_pairs_{timestamp}.csv'
    pair_datasets['validation'].to_csv(val_path, index=False)
    print(f"Validation dataset saved to: {val_path}")

    print("\nDataset Statistics:")
    print(f"Training pairs: {len(pair_datasets['train']):,}")
    print(f"Validation pairs: {len(pair_datasets['validation']):,}")
    
    train_pos = (pair_datasets['train']['label'] == 1).sum()
    train_neg = (pair_datasets['train']['label'] == 0).sum()
    val_pos = (pair_datasets['validation']['label'] == 1).sum()
    val_neg = (pair_datasets['validation']['label'] == 0).sum()
    
    print("\nClass Distribution:")
    print(f"Training - Positive: {train_pos:,} ({train_pos/len(pair_datasets['train'])*100:.1f}%)")
    print(f"Training - Negative: {train_neg:,} ({train_neg/len(pair_datasets['train'])*100:.1f}%)")
    print(f"Validation - Positive: {val_pos:,} ({val_pos/len(pair_datasets['validation'])*100:.1f}%)")
    print(f"Validation - Negative: {val_neg:,} ({val_neg/len(pair_datasets['validation'])*100:.1f}%)")

except Exception as e:
    print(f"\nError during pair generation: {str(e)}")
    raise

Initializing DataPairGenerator...
Loaded DataFrame with 10000 rows.
Loaded Tf-IDF embeddings with shape (10000, 5000).
Converting SBERT embeddings to PyTorch tensor...
Using GPU: Tesla P100-PCIE-16GB
Initialization complete.

Generating training pairs (total pairs: 100)...

Creating training pairs...
Generating positive pairs...

Generating positive pairs...


Processing batches:   0%|          | 0/10000 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/5 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/4 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/5 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/4 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/5 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/4 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/5 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/4 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/5 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/4 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/5 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/4 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/5 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/4 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/5 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/4 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/5 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/4 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/5 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/4 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/5 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/4 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/5 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/4 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/5 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/4 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/5 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/4 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/5 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/4 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/5 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/4 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/5 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/4 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/5 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/4 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/5 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/4 [00:00<?, ?it/s]

Computing SBERT similarities:   0%|          | 0/5 [00:00<?, ?it/s]

TF-IDF chunks:   0%|          | 0/4 [00:00<?, ?it/s]

Generated 13980 positive pairs
Generating 50 negative pairs...
Aiming to generate 50 negative pairs (maximum possible: 49995000)

Generating negative pairs...


Finding negative pairs:   0%|          | 0/50 [00:00<?, ?it/s]



Creating final datasets...
Merging with original dataframe...
Created 172 training pairs and 40 validation pairs
Total dataset creation took 15.83 seconds

Saving datasets...
Training dataset saved to: training_pairs_20250101-052504.csv
Validation dataset saved to: validation_pairs_20250101-052504.csv

Dataset Statistics:
Training pairs: 172
Validation pairs: 40

Class Distribution:
Training - Positive: 172 (100.0%)
Training - Negative: 0 (0.0%)
Validation - Positive: 40 (100.0%)
Validation - Negative: 0 (0.0%)


In [1]:
import pandas as pd
train_df = pd.read_csv('/kaggle/input/train-val-pair-for-siamese-network/training_pairs_20250101-093156.csv')
val_df = pd.read_csv('/kaggle/input/train-val-pair-for-siamese-network/training_pairs_20250101-093156.csv')

In [9]:
train_df.tail()

Unnamed: 0,nct_id_1,nct_id_2,label,Study Title_Cleaned_1,Primary Outcome Measures_Cleaned_1,Secondary Outcome Measures_Cleaned_1,criteria_Cleaned_1,Secondary_Outcome_Missing_1,Primary_Outcome_Missing_1,Criteria_Missing_1,completeness_score_1,Study Title_Cleaned_2,Primary Outcome Measures_Cleaned_2,Secondary Outcome Measures_Cleaned_2,criteria_Cleaned_2,Secondary_Outcome_Missing_2,Primary_Outcome_Missing_2,Criteria_Missing_2,completeness_score_2
198496,NCT04465266,NCT04465266,1,a phase 1 pk study of tolperisone in healthy s...,cmax maximum plasma concentration of tolperiso...,no secondary outcomes provided,inclusion criteria generally healthy subjects ...,1,0,0,0.8,a phase 1 pk study of tolperisone in healthy s...,cmax maximum plasma concentration of tolperiso...,no secondary outcomes provided,inclusion criteria generally healthy subjects ...,1,0,0,0.8
198497,NCT04465266,NCT04465266,1,a phase 1 pk study of tolperisone in healthy s...,cmax maximum plasma concentration of tolperiso...,no secondary outcomes provided,inclusion criteria generally healthy subjects ...,1,0,0,0.8,a phase 1 pk study of tolperisone in healthy s...,cmax maximum plasma concentration of tolperiso...,no secondary outcomes provided,inclusion criteria generally healthy subjects ...,1,0,0,0.8
198498,NCT04465266,NCT04465266,1,a phase 1 pk study of tolperisone in healthy s...,cmax maximum plasma concentration of tolperiso...,no secondary outcomes provided,inclusion criteria generally healthy subjects ...,1,0,0,0.8,a phase 1 pk study of tolperisone in healthy s...,cmax maximum plasma concentration of tolperiso...,no secondary outcomes provided,inclusion criteria generally healthy subjects ...,1,0,0,0.8
198499,NCT04465266,NCT04465266,1,a phase 1 pk study of tolperisone in healthy s...,cmax maximum plasma concentration of tolperiso...,no secondary outcomes provided,inclusion criteria generally healthy subjects ...,1,0,0,0.8,a phase 1 pk study of tolperisone in healthy s...,cmax maximum plasma concentration of tolperiso...,no secondary outcomes provided,inclusion criteria generally healthy subjects ...,1,0,0,0.8
198500,NCT04465266,NCT04465266,1,a phase 1 pk study of tolperisone in healthy s...,cmax maximum plasma concentration of tolperiso...,no secondary outcomes provided,inclusion criteria generally healthy subjects ...,1,0,0,0.8,a phase 1 pk study of tolperisone in healthy s...,cmax maximum plasma concentration of tolperiso...,no secondary outcomes provided,inclusion criteria generally healthy subjects ...,1,0,0,0.8


In [11]:
train_df['label'].value_counts()

label
1    198501
Name: count, dtype: int64

### Writing the model architecture

In [None]:
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer

class SiameseNetwork(nn.Module):
    def __init__(self, model_name='allenai/scibert_scivocab_uncased', hidden_size=768):
        super(SiameseNetwork, self).__init__()
        
        # Initialize BERT encoder
        self.encoder = AutoModel.from_pretrained(model_name)
        
        # Projection layers
        self.projection = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, hidden_size)
        )
        
    def forward_once(self, input_ids, attention_mask):
        # Get BERT outputs
        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )
        
        # Use CLS token representation
        pooled = outputs.last_hidden_state[:, 0]
        
        # Project to final embedding
        projected = self.projection(pooled)
        return projected
        
    def forward(self, input_ids1, attention_mask1, input_ids2, attention_mask2):
        # Get embeddings for both trials
        emb1 = self.forward_once(input_ids1, attention_mask1)
        emb2 = self.forward_once(input_ids2, attention_mask2)
        return emb1, emb2


In [None]:
class HybridEncoder(nn.Module):
    def __init__(self, 
                 model_name='allenai/scibert_scivocab_uncased', 
                 text_dim=768,
                 tfidf_dim=5000,
                 numeric_dim=3):  # completeness_score + 2 missing flags
        super(HybridEncoder, self).__init__()
        
        # Text encoder (BERT)
        self.text_encoder = AutoModel.from_pretrained(model_name)
        
        # TF-IDF encoder
        self.tfidf_encoder = nn.Sequential(
            nn.Linear(tfidf_dim, 256),
            nn.LayerNorm(256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, 128)
        )
        
        # Numeric features encoder
        self.numeric_encoder = nn.Sequential(
            nn.Linear(numeric_dim, 32),
            nn.LayerNorm(32),
            nn.ReLU(),
            nn.Linear(32, 16)
        )
        
        # Fusion layer
        combined_dim = text_dim + 128 + 16  # BERT + TF-IDF + numeric
        self.fusion = nn.Sequential(
            nn.Linear(combined_dim, 512),
            nn.LayerNorm(512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, 256)
        )
        
    def forward_once(self, input_ids, attention_mask, tfidf_vector, numeric_features):
        # Get BERT embeddings
        text_outputs = self.text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )
        text_embeds = text_outputs.last_hidden_state[:, 0]
        
        # Encode TF-IDF
        tfidf_embeds = self.tfidf_encoder(tfidf_vector)
        
        # Encode numeric features
        numeric_embeds = self.numeric_encoder(numeric_features)
        
        # Combine all features
        combined = torch.cat([text_embeds, tfidf_embeds, numeric_embeds], dim=1)
        
        # Final fusion
        fused = self.fusion(combined)
        return fused
        
    def forward(self, batch):
        # Process first trial
        emb1 = self.forward_once(
            batch['input_ids1'],
            batch['attention_mask1'],
            batch['tfidf1'],
            batch['numeric1']
        )
        
        # Process second trial
        emb2 = self.forward_once(
            batch['input_ids2'],
            batch['attention_mask2'],
            batch['tfidf2'],
            batch['numeric2']
        )
        
        return emb1, emb2


In [None]:
class ContrastiveLoss(nn.Module):
    def __init__(self, margin=1.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin
        
    def forward(self, emb1, emb2, label):
        # Calculate euclidean distance
        distance = F.pairwise_distance(emb1, emb2)
        
        # Contrastive loss
        loss = torch.mean((1 - label) * torch.pow(distance, 2) + 
                         label * torch.pow(torch.clamp(self.margin - distance, min=0.0), 2))
        return loss

In [None]:
class TrialDataset(Dataset):
    def __init__(self, df, tokenizer, max_length=512):
        self.df = df
        self.tokenizer = tokenizer
        self.max_length = max_length
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        
        # Tokenize text fields
        text1 = f"{row['Study Title_Cleaned_1']} {row['Primary Outcome Measures_Cleaned_1']} {row['criteria_Cleaned_1']}"
        text2 = f"{row['Study Title_Cleaned_2']} {row['Primary Outcome Measures_Cleaned_2']} {row['criteria_Cleaned_2']}"
        
        encoded1 = self.tokenizer(
            text1,
            truncation=True,
            max_length=self.max_length,
            padding='max_length',
            return_tensors='pt'
        )
        
        encoded2 = self.tokenizer(
            text2,
            truncation=True,
            max_length=self.max_length,
            padding='max_length',
            return_tensors='pt'
        )
        
        # Prepare numeric features
        numeric1 = torch.tensor([
            row['completeness_score_1'],
            row['Secondary_Outcome_Missing_1'],
            row['Primary_Outcome_Missing_1']
        ], dtype=torch.float)
        
        numeric2 = torch.tensor([
            row['completeness_score_2'],
            row['Secondary_Outcome_Missing_2'],
            row['Primary_Outcome_Missing_2']
        ], dtype=torch.float)
        
        return {
            'input_ids1': encoded1['input_ids'].squeeze(),
            'attention_mask1': encoded1['attention_mask'].squeeze(),
            'input_ids2': encoded2['input_ids'].squeeze(),
            'attention_mask2': encoded2['attention_mask'].squeeze(),
            'numeric1': numeric1,
            'numeric2': numeric2,
            'label': torch.tensor(row['label'], dtype=torch.float)
        }

def train_epoch(model, train_loader, criterion, optimizer, device):
    model.train()
    total_loss = 0
    
    for batch in tqdm(train_loader, desc='Training'):
        # Move batch to device
        batch = {k: v.to(device) for k, v in batch.items()}
        
        # Forward pass
        emb1, emb2 = model(batch)
        
        # Compute loss
        loss = criterion(emb1, emb2, batch['label'])
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    return total_loss / len(train_loader)

In [1]:
import pandas as pd
train_df = pd.read_csv('/kaggle/input/train-val-pair-for-siamese-network/training_pairs_20250101-093156.csv')
val_df = pd.read_csv('/kaggle/input/train-val-pair-for-siamese-network/validation_pairs_20250101-093156.csv')

In [2]:
val_df['label'].value_counts()

label
1    49477
Name: count, dtype: int64

In [6]:
import numpy as np
from tqdm.notebook import tqdm

In [None]:
def generate_negative_pairs(df, num_negative_pairs):
    """Generate negative pairs through random sampling"""
    negative_pairs = []
    unique_trials = df['nct_id_1'].unique()
    
    while len(negative_pairs) < num_negative_pairs:
        # Randomly sample two different trials
        trial1, trial2 = np.random.choice(unique_trials, size=2, replace=False)
        
        # Check if this pair doesn't exist in positive pairs
        if not df[(df['nct_id_1'] == trial1) & (df['nct_id_2'] == trial2)].shape[0]:
            negative_pairs.append({
                'nct_id_1': trial1,
                'nct_id_2': trial2,
                'label': 0
            })
    
    return pd.DataFrame(negative_pairs)

# Create balanced dataset
num_positive = len(train_df)
negative_pairs_df = generate_negative_pairs(train_df, num_positive)
balanced_train_df = pd.concat([train_df, negative_pairs_df], ignore_index=True)
new_num_positive = len(val_df)
new_negative_pairs_df = generate_negative_pairs(val_df, new_num_positive)
balanced_val_df = pd.concat([val_df, new_negative_pairs_df], ignore_index=True)

In [8]:
def generate_negative_pairs(df, num_negative_pairs):
    """Generate negative pairs through random sampling"""
    negative_pairs = []
    unique_trials = df['nct_id_1'].unique()
    
    # Initialize tqdm with the total number of negative pairs needed
    with tqdm(total=num_negative_pairs, desc='Generating negative pairs') as pbar:
        while len(negative_pairs) < num_negative_pairs:
            # Randomly sample two different trials
            trial1, trial2 = np.random.choice(unique_trials, size=2, replace=False)
            
            # Check if this pair doesn't exist in positive pairs
            if not df[(df['nct_id_1'] == trial1) & (df['nct_id_2'] == trial2)].shape[0]:
                negative_pairs.append({
                    'nct_id_1': trial1,
                    'nct_id_2': trial2,
                    'label': 0
                })
                # Update the progress bar
                pbar.update(1)
    
    return pd.DataFrame(negative_pairs)

num_positive = len(train_df)
negative_pairs_df = generate_negative_pairs(train_df, num_positive)
balanced_train_df = pd.concat([train_df, negative_pairs_df], ignore_index=True)
new_num_positive = len(val_df)
new_negative_pairs_df = generate_negative_pairs(val_df, new_num_positive)
balanced_val_df = pd.concat([val_df, new_negative_pairs_df], ignore_index=True)

Generating negative pairs:   0%|          | 0/198501 [00:00<?, ?it/s]

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=10000.0 (msgs/sec)
NotebookApp.rate_limit_window=1.0 (secs)

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=10000.0 (msgs/sec)
NotebookApp.rate_limit_window=1.0 (secs)



In [21]:
train_df.tail()

Unnamed: 0,nct_id_1,nct_id_2,label,Study Title_Cleaned_1,Primary Outcome Measures_Cleaned_1,Secondary Outcome Measures_Cleaned_1,criteria_Cleaned_1,Secondary_Outcome_Missing_1,Primary_Outcome_Missing_1,Criteria_Missing_1,completeness_score_1,Study Title_Cleaned_2,Primary Outcome Measures_Cleaned_2,Secondary Outcome Measures_Cleaned_2,criteria_Cleaned_2,Secondary_Outcome_Missing_2,Primary_Outcome_Missing_2,Criteria_Missing_2,completeness_score_2
198496,NCT04465266,NCT04465266,1,a phase 1 pk study of tolperisone in healthy s...,cmax maximum plasma concentration of tolperiso...,no secondary outcomes provided,inclusion criteria generally healthy subjects ...,1,0,0,0.8,a phase 1 pk study of tolperisone in healthy s...,cmax maximum plasma concentration of tolperiso...,no secondary outcomes provided,inclusion criteria generally healthy subjects ...,1,0,0,0.8
198497,NCT04465266,NCT04465266,1,a phase 1 pk study of tolperisone in healthy s...,cmax maximum plasma concentration of tolperiso...,no secondary outcomes provided,inclusion criteria generally healthy subjects ...,1,0,0,0.8,a phase 1 pk study of tolperisone in healthy s...,cmax maximum plasma concentration of tolperiso...,no secondary outcomes provided,inclusion criteria generally healthy subjects ...,1,0,0,0.8
198498,NCT04465266,NCT04465266,1,a phase 1 pk study of tolperisone in healthy s...,cmax maximum plasma concentration of tolperiso...,no secondary outcomes provided,inclusion criteria generally healthy subjects ...,1,0,0,0.8,a phase 1 pk study of tolperisone in healthy s...,cmax maximum plasma concentration of tolperiso...,no secondary outcomes provided,inclusion criteria generally healthy subjects ...,1,0,0,0.8
198499,NCT04465266,NCT04465266,1,a phase 1 pk study of tolperisone in healthy s...,cmax maximum plasma concentration of tolperiso...,no secondary outcomes provided,inclusion criteria generally healthy subjects ...,1,0,0,0.8,a phase 1 pk study of tolperisone in healthy s...,cmax maximum plasma concentration of tolperiso...,no secondary outcomes provided,inclusion criteria generally healthy subjects ...,1,0,0,0.8
198500,NCT04465266,NCT04465266,1,a phase 1 pk study of tolperisone in healthy s...,cmax maximum plasma concentration of tolperiso...,no secondary outcomes provided,inclusion criteria generally healthy subjects ...,1,0,0,0.8,a phase 1 pk study of tolperisone in healthy s...,cmax maximum plasma concentration of tolperiso...,no secondary outcomes provided,inclusion criteria generally healthy subjects ...,1,0,0,0.8


In [None]:
class ValidationMetrics:
    def __init__(self, k_values=[1, 5, 10]):
        self.k_values = k_values
    
    def precision_at_k(self, similarities, labels, k):
        """Calculate Precision@K"""
        _, top_k_indices = similarities.topk(k)
        relevant = labels[top_k_indices].float()
        return (relevant.sum() / k).mean().item()
    
    def mean_reciprocal_rank(self, similarities, labels):
        """Calculate MRR"""
        ranks = torch.where(labels[similarities.argsort(descending=True)])[0] + 1
        return (1.0 / ranks.float()).mean().item()
    
    def compute_metrics(self, embeddings1, embeddings2, labels):
        """Compute all validation metrics"""
        # Compute similarity matrix
        similarities = F.cosine_similarity(embeddings1.unsqueeze(1), 
                                        embeddings2.unsqueeze(0), dim=2)
        
        metrics = {
            'mrr': self.mean_reciprocal_rank(similarities, labels)
        }
        
        for k in self.k_values:
            metrics[f'p@{k}'] = self.precision_at_k(similarities, labels, k)
            
        return metrics

def validate(model, val_loader, criterion, metrics, device):
    """Validation step"""
    model.eval()
    total_loss = 0
    all_emb1, all_emb2, all_labels = [], [], []
    
    with torch.no_grad():
        for batch in val_loader:
            batch = {k: v.to(device) for k, v in batch.items()}
            
            # Forward pass
            emb1, emb2 = model(batch)
            
            # Compute loss
            loss = criterion(emb1, emb2, batch['label'])
            total_loss += loss.item()
            
            # Store embeddings and labels for metrics
            all_emb1.append(emb1)
            all_emb2.append(emb2)
            all_labels.append(batch['label'])
    
    # Concatenate all batches
    all_emb1 = torch.cat(all_emb1)
    all_emb2 = torch.cat(all_emb2)
    all_labels = torch.cat(all_labels)
    
    # Compute metrics
    validation_metrics = metrics.compute_metrics(all_emb1, all_emb2, all_labels)
    validation_metrics['loss'] = total_loss / len(val_loader)
    
    return validation_metrics

In [None]:
def train_model(model, train_loader, val_loader, config):
    """Training with hyperparameter configuration"""
    criterion = ContrastiveLoss(margin=config['margin'])
    optimizer = torch.optim.AdamW(model.parameters(), 
                                lr=config['learning_rate'],
                                weight_decay=config['weight_decay'])
    
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=2, verbose=True
    )
    
    metrics = ValidationMetrics(k_values=[1, 5, 10])
    best_val_loss = float('inf')
    patience = config['patience']
    patience_counter = 0
    
    for epoch in range(config['epochs']):
        # Training
        train_loss = train_epoch(model, train_loader, criterion, optimizer, config['device'])
        
        # Validation
        val_metrics = validate(model, val_loader, criterion, metrics, config['device'])
        
        # Learning rate scheduling
        scheduler.step(val_metrics['loss'])
        
        # Early stopping
        if val_metrics['loss'] < best_val_loss:
            best_val_loss = val_metrics['loss']
            patience_counter = 0
            # Save best model
            torch.save(model.state_dict(), 'best_model.pt')
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch}")
                break
        
        # Print metrics
        print(f"Epoch {epoch}:")
        print(f"Train Loss: {train_loss:.4f}")
        print(f"Val Loss: {val_metrics['loss']:.4f}")
        print(f"Val MRR: {val_metrics['mrr']:.4f}")
        for k in metrics.k_values:
            print(f"Val P@{k}: {val_metrics[f'p@{k}']:.4f}")


In [None]:
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
import os
from datetime import datetime

def setup_training(config):
    # Set random seeds for reproducibility
    torch.manual_seed(config['seed'])
    torch.cuda.manual_seed(config['seed'])
    
    # Initialize tokenizer
    tokenizer = AutoTokenizer.from_pretrained(config['model_name'])
    
    # Create datasets
    train_dataset = TrialDataset(balanced_train_df, tokenizer, max_length=config['max_length'])
    val_dataset = TrialDataset(val_df, tokenizer, max_length=config['max_length'])
    
    # Create dataloaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=config['num_workers']
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=config['batch_size'],
        shuffle=False,
        num_workers=config['num_workers']
    )
    
    # Initialize model
    model = HybridEncoder(
        model_name=config['model_name'],
        text_dim=config['text_dim'],
        tfidf_dim=config['tfidf_dim'],
        numeric_dim=config['numeric_dim']
    ).to(config['device'])
    
    return model, train_loader, val_loader

def save_checkpoint(model, optimizer, epoch, val_metrics, checkpoint_dir):
    """Save model checkpoint"""
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    checkpoint_path = os.path.join(
        checkpoint_dir, 
        f'checkpoint_epoch_{epoch}_{timestamp}.pt'
    )
    
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'val_metrics': val_metrics,
    }, checkpoint_path)
    
    return checkpoint_path

import json
import pandas as pd
from pathlib import Path

class MetricLogger:
    def __init__(self, log_dir):
        self.log_dir = Path(log_dir)
        self.log_dir.mkdir(exist_ok=True)
        
        # Initialize containers for metrics
        self.metrics = {
            'train_loss': [],
            'val_loss': [],
            'val_mrr': [],
            'learning_rate': [],
            'epoch': []
        }
        # Add P@K metrics
        for k in [1, 5, 10]:
            self.metrics[f'val_p@{k}'] = []
            
    def log_metrics(self, epoch, train_loss, val_metrics, current_lr):
        """Log metrics for one epoch"""
        self.metrics['epoch'].append(epoch)
        self.metrics['train_loss'].append(train_loss)
        self.metrics['val_loss'].append(val_metrics['loss'])
        self.metrics['val_mrr'].append(val_metrics['mrr'])
        self.metrics['learning_rate'].append(current_lr)
        
        # Log P@K metrics
        for k in [1, 5, 10]:
            self.metrics[f'val_p@{k}'].append(val_metrics[f'p@{k}'])
    
    def save_metrics(self):
        """Save metrics to files"""
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        
        # Save as CSV
        df = pd.DataFrame(self.metrics)
        csv_path = self.log_dir / f'metrics_{timestamp}.csv'
        df.to_csv(csv_path, index=False)
        
        # Save as JSON
        json_path = self.log_dir / f'metrics_{timestamp}.json'
        with open(json_path, 'w') as f:
            json.dump(self.metrics, f, indent=4)
        
        return csv_path, json_path

def train_model(model, train_loader, val_loader, config):
    """Training with metric logging"""
    # Initialize metric logger
    logger = MetricLogger(config['log_dir'])
    
    criterion = ContrastiveLoss(margin=config['margin'])
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config['learning_rate'],
        weight_decay=config['weight_decay']
    )
    
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=2, verbose=True
    )
    
    metrics = ValidationMetrics(k_values=[1, 5, 10])
    best_val_loss = float('inf')
    best_model_path = None
    
    for epoch in range(config['epochs']):
        print(f"\nEpoch {epoch+1}/{config['epochs']}")
        
        # Training phase
        train_loss = train_epoch(model, train_loader, criterion, optimizer, config['device'])
        
        # Validation phase
        val_metrics = validate(model, val_loader, criterion, metrics, config['device'])
        
        # Log metrics
        current_lr = optimizer.param_groups[0]['lr']
        logger.log_metrics(epoch, train_loss, val_metrics, current_lr)
        
        # Learning rate scheduling
        scheduler.step(val_metrics['loss'])
        
        # Save checkpoint
        checkpoint_path = save_checkpoint(
            model, optimizer, epoch, val_metrics, config['checkpoint_dir']
        )
        
        # Save best model
        if val_metrics['loss'] < best_val_loss:
            best_val_loss = val_metrics['loss']
            if best_model_path and os.path.exists(best_model_path):
                os.remove(best_model_path)
            best_model_path = os.path.join(config['checkpoint_dir'], 'best_model.pt')
            torch.save(model.state_dict(), best_model_path)
        
        # Print metrics
        print(f"Train Loss: {train_loss:.4f}")
        print(f"Validation Loss: {val_metrics['loss']:.4f}")
        print(f"Validation MRR: {val_metrics['mrr']:.4f}")
        for k in metrics.k_values:
            print(f"Validation P@{k}: {val_metrics[f'p@{k}']:.4f}")
        
        # Early stopping check
        if val_metrics['loss'] > best_val_loss * (1 + config['early_stopping_threshold']):
            config['patience_counter'] += 1
            if config['patience_counter'] >= config['patience']:
                print("Early stopping triggered")
                break
        else:
            config['patience_counter'] = 0
    
    # Save final metrics
    csv_path, json_path = logger.save_metrics()
    print(f"\nMetrics saved to:")
    print(f"CSV: {csv_path}")
    print(f"JSON: {json_path}")

# Update config with log directory
config.update({
    'log_dir': 'training_logs'
})


# Configuration
config = {
    'model_name': 'allenai/scibert_scivocab_uncased',
    'text_dim': 768,
    'tfidf_dim': 5000,  # Adjust based on your TF-IDF vector size
    'numeric_dim': 3,
    'max_length': 512,
    'batch_size': 16,
    'num_workers': 4,
    'learning_rate': 2e-5,
    'weight_decay': 0.01,
    'margin': 1.0,
    'epochs': 10,
    'patience': 3,
    'early_stopping_threshold': 0.01,
    'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    'checkpoint_dir': 'checkpoints',
    'seed': 42,
    'patience_counter': 0
}

# Run training
if __name__ == "__main__":
    print("Initializing training...")
    model, train_loader, val_loader = setup_training(config)
    
    print("Starting training...")
    train_model(model, train_loader, val_loader, config)


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel
import os
import numpy as np
from datetime import datetime
import json
import pandas as pd
from pathlib import Path
from tqdm import tqdm

class TrialDataset(Dataset):
    def __init__(self, df, tokenizer, max_length=512):
        self.df = df
        self.tokenizer = tokenizer
        self.max_length = max_length
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        
        # Combine text fields
        text1 = f"{row['Study Title_Cleaned_1']} {row['Primary Outcome Measures_Cleaned_1']} {row['criteria_Cleaned_1']}"
        text2 = f"{row['Study Title_Cleaned_2']} {row['Primary Outcome Measures_Cleaned_2']} {row['criteria_Cleaned_2']}"
        
        # Tokenize texts
        encoded1 = self.tokenizer(
            text1,
            truncation=True,
            max_length=self.max_length,
            padding='max_length',
            return_tensors='pt'
        )
        
        encoded2 = self.tokenizer(
            text2,
            truncation=True,
            max_length=self.max_length,
            padding='max_length',
            return_tensors='pt'
        )
        
        # Prepare numeric features
        numeric1 = torch.tensor([
            row['completeness_score_1'],
            row['Secondary_Outcome_Missing_1'],
            row['Primary_Outcome_Missing_1']
        ], dtype=torch.float)
        
        numeric2 = torch.tensor([
            row['completeness_score_2'],
            row['Secondary_Outcome_Missing_2'],
            row['Primary_Outcome_Missing_2']
        ], dtype=torch.float)
        
        return {
            'input_ids1': encoded1['input_ids'].squeeze(),
            'attention_mask1': encoded1['attention_mask'].squeeze(),
            'input_ids2': encoded2['input_ids'].squeeze(),
            'attention_mask2': encoded2['attention_mask'].squeeze(),
            'numeric1': numeric1,
            'numeric2': numeric2,
            'label': torch.tensor(row['label'], dtype=torch.float)
        }

class ContrastiveLoss(nn.Module):
    def __init__(self, margin=1.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin
        
    def forward(self, emb1, emb2, label):
        distance = F.pairwise_distance(emb1, emb2)
        loss = torch.mean((1 - label) * torch.pow(distance, 2) + 
                         label * torch.pow(torch.clamp(self.margin - distance, min=0.0), 2))
        return loss

class ValidationMetrics:
    def __init__(self, k_values=[1, 5, 10]):
        self.k_values = k_values
    
    def precision_at_k(self, similarities, labels, k):
        _, top_k_indices = similarities.topk(k)
        relevant = labels[top_k_indices].float()
        return (relevant.sum() / k).mean().item()
    
    def mean_reciprocal_rank(self, similarities, labels):
        ranks = torch.where(labels[similarities.argsort(descending=True)])[0] + 1
        return (1.0 / ranks.float()).mean().item()
    
    def compute_metrics(self, embeddings1, embeddings2, labels):
        similarities = F.cosine_similarity(embeddings1.unsqueeze(1), 
                                        embeddings2.unsqueeze(0), dim=2)
        
        metrics = {
            'mrr': self.mean_reciprocal_rank(similarities, labels)
        }
        
        for k in self.k_values:
            metrics[f'p@{k}'] = self.precision_at_k(similarities, labels, k)
            
        return metrics

class MetricLogger:
    def __init__(self, log_dir):
        self.log_dir = Path(log_dir)
        self.log_dir.mkdir(exist_ok=True)
        
        self.metrics = {
            'train_loss': [],
            'val_loss': [],
            'val_mrr': [],
            'learning_rate': [],
            'epoch': []
        }
        for k in [1, 5, 10]:
            self.metrics[f'val_p@{k}'] = []
            
    def log_metrics(self, epoch, train_loss, val_metrics, current_lr):
        self.metrics['epoch'].append(epoch)
        self.metrics['train_loss'].append(train_loss)
        self.metrics['val_loss'].append(val_metrics['loss'])
        self.metrics['val_mrr'].append(val_metrics['mrr'])
        self.metrics['learning_rate'].append(current_lr)
        
        for k in [1, 5, 10]:
            self.metrics[f'val_p@{k}'].append(val_metrics[f'p@{k}'])
    
    def save_metrics(self):
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        
        df = pd.DataFrame(self.metrics)
        csv_path = self.log_dir / f'metrics_{timestamp}.csv'
        df.to_csv(csv_path, index=False)
        
        json_path = self.log_dir / f'metrics_{timestamp}.json'
        with open(json_path, 'w') as f:
            json.dump(self.metrics, f, indent=4)
        
        return csv_path, json_path

def train_epoch(model, train_loader, criterion, optimizer, device):
    model.train()
    total_loss = 0
    
    for batch in tqdm(train_loader, desc='Training'):
        try:
            batch = {k: v.to(device) for k, v in batch.items()}
            
            optimizer.zero_grad()
            emb1, emb2 = model(batch)
            loss = criterion(emb1, emb2, batch['label'])
            
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            
            total_loss += loss.item()
            
        except RuntimeError as e:
            print(f"Error in batch: {str(e)}")
            continue
            
    return total_loss / len(train_loader)

def validate(model, val_loader, criterion, metrics, device):
    model.eval()
    total_loss = 0
    all_emb1, all_emb2, all_labels = [], [], []
    
    with torch.no_grad():
        for batch in val_loader:
            batch = {k: v.to(device) for k, v in batch.items()}
            emb1, emb2 = model(batch)
            loss = criterion(emb1, emb2, batch['label'])
            
            total_loss += loss.item()
            all_emb1.append(emb1.cpu())
            all_emb2.append(emb2.cpu())
            all_labels.append(batch['label'].cpu())
            
            torch.cuda.empty_cache()
    
    all_emb1 = torch.cat(all_emb1)
    all_emb2 = torch.cat(all_emb2)
    all_labels = torch.cat(all_labels)
    
    validation_metrics = metrics.compute_metrics(all_emb1, all_emb2, all_labels)
    validation_metrics['loss'] = total_loss / len(val_loader)
    
    return validation_metrics

def setup_training(config):
    torch.manual_seed(config['seed'])
    torch.cuda.manual_seed(config['seed'])
    
    tokenizer = AutoTokenizer.from_pretrained(config['model_name'])
    
    train_dataset = TrialDataset(balanced_train_df, tokenizer, max_length=config['max_length'])
    val_dataset = TrialDataset(val_df, tokenizer, max_length=config['max_length'])
    
    train_loader = DataLoader(
        train_dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=config['num_workers']
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=config['batch_size'],
        shuffle=False,
        num_workers=config['num_workers']
    )
    
    model = HybridEncoder(
        model_name=config['model_name'],
        text_dim=config['text_dim'],
        tfidf_dim=config['tfidf_dim'],
        numeric_dim=config['numeric_dim']
    ).to(config['device'])
    
    return model, train_loader, val_loader

def save_checkpoint(model, optimizer, epoch, val_metrics, checkpoint_dir):
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    checkpoint_path = os.path.join(
        checkpoint_dir, 
        f'checkpoint_epoch_{epoch}_{timestamp}.pt'
    )
    
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'val_metrics': val_metrics,
    }, checkpoint_path)
    
    return checkpoint_path

def train_model(model, train_loader, val_loader, config):
    try:
        logger = MetricLogger(config['log_dir'])
        
        criterion = ContrastiveLoss(margin=config['margin'])
        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=config['learning_rate'],
            weight_decay=config['weight_decay']
        )
        
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.5, patience=2, verbose=True
        )
        
        metrics = ValidationMetrics(k_values=[1, 5, 10])
        best_val_loss = float('inf')
        best_model_path = None
        
        for epoch in range(config['epochs']):
            print(f"\nEpoch {epoch+1}/{config['epochs']}")
            
            train_loss = train_epoch(model, train_loader, criterion, optimizer, config['device'])
            val_metrics = validate(model, val_loader, criterion, metrics, config['device'])
            
            current_lr = optimizer.param_groups[0]['lr']
            logger.log_metrics(epoch, train_loss, val_metrics, current_lr)
            
            scheduler.step(val_metrics['loss'])
            
            checkpoint_path = save_checkpoint(
                model, optimizer, epoch, val_metrics, config['checkpoint_dir']
            )
            
            if val_metrics['loss'] < best_val_loss:
                best_val_loss = val_metrics['loss']
                if best_model_path and os.path.exists(best_model_path):
                    os.remove(best_model_path)
                best_model_path = os.path.join(config['checkpoint_dir'], 'best_model.pt')
                torch.save(model.state_dict(), best_model_path)
            
            print(f"Train Loss: {train_loss:.4f}")
            print(f"Validation Loss: {val_metrics['loss']:.4f}")
            print(f"Validation MRR: {val_metrics['mrr']:.4f}")
            for k in metrics.k_values:
                print(f"Validation P@{k}: {val_metrics[f'p@{k}']:.4f}")
            
            if val_metrics['loss'] > best_val_loss * (1 + config['early_stopping_threshold']):
                config['patience_counter'] += 1
                if config['patience_counter'] >= config['patience']:
                    print("Early stopping triggered")
                    break
            else:
                config['patience_counter'] = 0
        
        csv_path, json_path = logger.save_metrics()
        print(f"\nMetrics saved to:")
        print(f"CSV: {csv_path}")
        print(f"JSON: {json_path}")
        
    except Exception as e:
        print(f"Error during training: {str(e)}")
        if 'model' in locals() and 'optimizer' in locals():
            save_checkpoint(model, optimizer, epoch, val_metrics, config['checkpoint_dir'])
        raise e

config = {
    'model_name': 'allenai/scibert_scivocab_uncased',
    'text_dim': 768,
    'tfidf_dim': 5000,
    'numeric_dim': 3,
    'max_length': 512,
    'batch_size': 16,
    'num_workers': 4,
    'learning_rate': 2e-5,
    'weight_decay': 0.01,
    'margin': 1.0,
    'epochs': 10,
    'patience': 3,
    'early_stopping_threshold': 0.01,
    'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    'checkpoint_dir': 'checkpoints',
    'log_dir': 'training_logs',
    'seed': 42,
    'patience_counter': 0,
    'grad_clip': 1.0
}

if __name__ == "__main__":
    print("Initializing training...")
    model, train_loader, val_loader = setup_training(config)
    
    print("Starting training...")
    train_model(model, train_loader, val_loader, config)
