# In [None]:  (notebook cell marker, kept as a comment so the file parses as Python)
# train.py
import os
import pandas as pd
import numpy as np
import faiss
from transformers import AutoTokenizer, AutoModel
import torch
import yaml

# Load config file
# Every key below is read unconditionally further down, so they are validated
# together up front and reported in a single error.
_REQUIRED_CONFIG_KEYS = (
    'data_source',
    'output_index',
    'output_embeddings',
    'model_name',
    'model_directory',
    'clean_up_tokenization_spaces',
)

# Explicit encoding so the YAML is decoded the same way on every platform.
with open('config.yaml', 'r', encoding='utf-8') as f:
    # safe_load returns None for an empty file; fall back to an empty mapping
    # so the check below reports the missing keys instead of a TypeError.
    config = yaml.safe_load(f) or {}

_missing_config_keys = [key for key in _REQUIRED_CONFIG_KEYS if key not in config]
if _missing_config_keys:
    raise KeyError(
        f"config.yaml is missing required keys: {', '.join(_missing_config_keys)}"
    )

data_source = config['data_source']  # Path of the CSV file read below
output_index = config['output_index']  # Destination path of the FAISS index
output_embeddings = config['output_embeddings']  # Destination path of the pickled DataFrame
model_name = config['model_name']  # Model identifier passed to from_pretrained
model_directory = config['model_directory']  # Parametrized model storage directory
clean_up_tokenization_spaces = config['clean_up_tokenization_spaces']  # Parametrized tokenizer cleanup

# Load data
df = pd.read_csv(data_source)

# Check if the model is in the folder, else download it
# The tokenizer and the model are stored in separate sub-directories of
# model_directory; both must be present for an offline reload to succeed.
_tokenizer_directory = os.path.join(model_directory, "tokenizer")
_model_weights_directory = os.path.join(model_directory, "model")

if os.path.isdir(_tokenizer_directory) and os.path.isdir(_model_weights_directory):
    # Local copies exist: reload them from disk. The tokenizer cleanup option
    # is passed here as well so both branches configure the tokenizer alike.
    tokenizer = AutoTokenizer.from_pretrained(
        _tokenizer_directory,
        clean_up_tokenization_spaces=clean_up_tokenization_spaces,
    )
    model = AutoModel.from_pretrained(_model_weights_directory)
else:
    # No (complete) local copy: load the tokenizer and model by name.
    # Checking the sub-directories rather than model_directory itself means an
    # empty or partially populated directory triggers a fresh download
    # instead of a failing reload.
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        clean_up_tokenization_spaces=clean_up_tokenization_spaces,
    )
    model = AutoModel.from_pretrained(model_name)

    # Save both to the specified directory so later runs can reload them.
    tokenizer.save_pretrained(_tokenizer_directory)
    model.save_pretrained(_model_weights_directory)

# Generate embeddings
def get_embeddings(text):
    """Return a mean-pooled embedding of *text* as a NumPy array.

    The text is tokenized with the module-level ``tokenizer`` (truncation and
    padding enabled), run through the module-level ``model`` without gradient
    tracking, and the last hidden state is averaged over dimension 1
    (presumably the token axis -- confirm against the model's output layout).

    When the tokenizer returns an ``attention_mask``, positions where the mask
    is 0 are excluded from the average, so padding added for a batch of
    unequal-length texts does not dilute the result. For a single string the
    mask is all ones and this equals a plain mean.

    Args:
        text: Input passed straight to the tokenizer (the script calls this
            with one string per DataFrame row).

    Returns:
        The pooled hidden state with all size-1 dimensions squeezed out,
        converted to a NumPy array.
    """
    inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True)
    with torch.no_grad():
        outputs = model(**inputs)

    hidden_states = outputs.last_hidden_state
    attention_mask = inputs.get('attention_mask')
    if attention_mask is None:
        # Tokenizer produced no mask: fall back to an unweighted mean.
        pooled = hidden_states.mean(dim=1)
    else:
        # Broadcast the mask over the hidden dimension and average only the
        # positions the mask marks as real tokens.
        mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)
        # clamp guards against a division by zero for an all-zero mask row.
        token_counts = mask.sum(dim=1).clamp(min=1)
        pooled = (hidden_states * mask).sum(dim=1) / token_counts

    return pooled.squeeze().numpy()

# Fail early with a clear message: np.stack below cannot stack zero arrays.
if df.empty:
    raise ValueError(f"No rows found in {data_source!r}; nothing to embed.")

# One embedding per row; the function is passed directly (no lambda wrapper).
df['embeddings'] = df['text'].apply(get_embeddings)

# Create FAISS index
# FAISS expects a C-contiguous float32 matrix; ascontiguousarray is a no-op
# when the stacked embeddings already satisfy that.
embeddings_matrix = np.ascontiguousarray(
    np.stack(df['embeddings'].values), dtype=np.float32
)
# Flat L2 index sized to the embedding dimension (second axis of the matrix).
index = faiss.IndexFlatL2(embeddings_matrix.shape[1])
index.add(embeddings_matrix)

# Save FAISS index and DataFrame with embeddings
# Create missing parent directories first so the results of the embedding pass
# are not lost to a "no such file or directory" error at write time.
for _output_path in (output_index, output_embeddings):
    _output_parent = os.path.dirname(_output_path)
    if _output_parent:  # empty when the path has no directory component
        os.makedirs(_output_parent, exist_ok=True)

faiss.write_index(index, output_index)
# NOTE: the DataFrame is stored as a pickle; only unpickle it from a trusted
# location, since loading a pickle can execute arbitrary code.
df.to_pickle(output_embeddings)

print("Training complete. FAISS index and embeddings saved locally.")
