# Model Training & Data Ingestion for RAG System

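This notebook prepares the product data and ingests it into the Weaviate vector database. This is the core 'data modeling' step for our Retrieval-Augmented Generation (RAG) system: no model is fine-tuned here. Instead, pre-trained text and image encoders turn each product into a vector that Weaviate can search.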

**Workflow:**
1.  **Load Data:** Read the `intern_data_ikarus.csv` file.
2.  **Preprocess Data:** Drop rows that lack a title or image list, limit the set to 500 records for development, and parse each product's list of image URLs.
3.  **Generate Embeddings:** Create vector representations (embeddings) for each product using:
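    * **Text Embeddings:** The `sentence-transformers` model `all-MiniLM-L6-v2`, which produces 384-dimensional vectors.
    * **Image Embeddings:** A pre-trained `ResNet-50` with its final classification layer removed, which produces a 2048-dimensional vector from the first image of each product.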
4.  **Connect to Weaviate:** Establish a connection with the local Weaviate instance.
5.  **Define Schema:** Create a 'Product' collection schema in Weaviate.
6.  **Ingest Data:** Batch-upload each product's properties into Weaviate together with its combined vector (384 text + 2048 image = 2432 dimensions).

### 1. Load Libraries and Data

In [None]:
import pandas as pd
import weaviate
import ast
import requests
from PIL import Image
from io import BytesIO
import torch
from torchvision import transforms
from torchvision.models import resnet50
from sentence_transformers import SentenceTransformer
import numpy as np
from tqdm import tqdm

In [None]:
# Load the dataset
file_path = '../backend/data/intern_data_ikarus.csv'
df = pd.read_csv(file_path).dropna(subset=['title', 'images'])
df = df.head(500) # For faster processing during development, limit to 500 records
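The ingestion loop parses the `images` column with `ast.literal_eval`, which assumes each cell holds a Python-style list literal such as `"['https://a.jpg', 'https://b.jpg']"`. A more forgiving parser is sketched below under that assumption. `parse_image_urls` is a hypothetical helper, and treating a bare URL string as a one-element list is an assumed fallback rather than something the dataset is known to need.

```python
import ast
import math


def parse_image_urls(cell):
    """Parse an `images` cell into a list of stripped URL strings.

    Assumes cells normally look like "['https://a.jpg', 'https://b.jpg']".
    Missing values (None/NaN) give an empty list; a bare URL string is
    treated as a single-image list (hypothetical fallback).
    """
    if cell is None or (isinstance(cell, float) and math.isnan(cell)):
        return []
    text = str(cell).strip()
    try:
        parsed = ast.literal_eval(text)
    except (ValueError, SyntaxError, TypeError):
        # Not a Python literal: keep it only if it looks like a single URL
        return [text] if text.startswith(("http://", "https://")) else []
    if isinstance(parsed, str):
        parsed = [parsed]  # A single quoted URL rather than a list
    if not isinstance(parsed, (list, tuple)):
        return []
    # Keep only non-empty string entries, without surrounding whitespace
    return [u.strip() for u in parsed if isinstance(u, str) and u.strip()]


print(parse_image_urls("['https://a.jpg', ' https://b.jpg']"))  # ['https://a.jpg', 'https://b.jpg']
print(parse_image_urls("https://c.jpg"))                         # ['https://c.jpg']
print(parse_image_urls(float("nan")))                            # []
```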

### 2. Initialize Models for Embedding Generation

In [None]:
# Set device
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# --- Text Embedding Model ---
text_embedder = SentenceTransformer('all-MiniLM-L6-v2', device=device)

# --- Image Embedding Model (ResNet50) ---
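# `pretrained=True` is deprecated since torchvision 0.13; request the equivalent ImageNet weights explicitly
from torchvision.models import ResNet50_Weights
image_model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)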
image_model = torch.nn.Sequential(*(list(image_model.children())[:-1])) # Remove the final classification layer
image_model.eval()
image_model.to(device)

# Define image transformations
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
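`transforms.ToTensor()` converts an H×W×C `uint8` image with values in [0, 255] into a C×H×W float tensor in [0, 1], and `transforms.Normalize` then computes `(x - mean[c]) / std[c]` for each channel `c` using the ImageNet statistics above. The NumPy sketch below reproduces only those two steps (not the resize and crop) to make the arithmetic concrete; `to_normalized_chw` is a hypothetical name.

```python
import numpy as np

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def to_normalized_chw(image_hwc_uint8):
    """Mimic ToTensor() + Normalize(): HxWx3 uint8 -> 3xHxW float32, standardized per channel."""
    x = image_hwc_uint8.astype(np.float32) / 255.0  # ToTensor: scale to [0, 1]
    x = np.transpose(x, (2, 0, 1))                  # ToTensor: HWC -> CHW
    # Normalize: broadcast the per-channel statistics over the H and W axes
    return (x - IMAGENET_MEAN[:, None, None]) / IMAGENET_STD[:, None, None]


white = np.full((224, 224, 3), 255, dtype=np.uint8)
out = to_normalized_chw(white)
print(out.shape)  # (3, 224, 224)
```

A pure white pixel therefore maps to roughly 2.25, 2.43 and 2.64 in the three channels, which is the input range the ImageNet-trained ResNet-50 expects.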

### 3. Define Embedding Helper Functions

In [None]:
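def get_image_embedding(image_url):
    """Return a 2048-d ResNet-50 embedding for the image at `image_url`, or a zero vector on failure."""
    try:
        response = requests.get(image_url.strip(), timeout=10)
        response.raise_for_status()  # Treat HTTP errors (404, 403, ...) as failures instead of decoding an error page
        img = Image.open(BytesIO(response.content)).convert('RGB')
        img_t = preprocess(img)
        batch_t = torch.unsqueeze(img_t, 0).to(device)  # Add a batch dimension: (1, 3, 224, 224)
        with torch.no_grad():
            embedding = image_model(batch_t)  # Shape (1, 2048, 1, 1) after global average pooling
        return embedding.squeeze().cpu().numpy()
    except Exception as e:
        # print(f"Error processing image {image_url}: {e}")
        return np.zeros(2048, dtype=np.float32)  # Zero vector on error, same dtype as the model output

def get_text_embedding(text):
    """Return the 384-d all-MiniLM-L6-v2 embedding of `text` as a NumPy array."""
    return text_embedder.encode(text)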
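The ingestion loop concatenates the raw text and image vectors. Because the two parts come from different models, their magnitudes can differ a lot (the ResNet-50 pooled features are not length-normalized), so under cosine distance the larger part can dominate. One common alternative, sketched here as an assumption rather than what this notebook does, is to L2-normalize each part and scale it by a weight before concatenating. `l2_normalize`, `combine_embeddings`, and the default `text_weight=0.5` are hypothetical choices.

```python
import numpy as np


def l2_normalize(v, eps=1e-12):
    """Scale `v` to unit L2 norm; an all-zero vector (e.g. a failed image download) is returned unchanged."""
    v = np.asarray(v, dtype=np.float32)
    norm = np.linalg.norm(v)
    return v if norm < eps else v / norm


def combine_embeddings(text_vec, image_vec, text_weight=0.5):
    """Concatenate unit-normalized text and image vectors, weighted so the result has unit norm.

    `text_weight` is the share of the squared norm given to the text part (hypothetical default 0.5).
    If the image part is all zeros, the text part alone is rescaled to unit norm.
    """
    t = l2_normalize(text_vec)
    i = l2_normalize(image_vec)
    alpha = np.sqrt(text_weight)       # Weight on the text block
    beta = np.sqrt(1.0 - text_weight)  # Weight on the image block
    return l2_normalize(np.concatenate([alpha * t, beta * i]))


rng = np.random.default_rng(0)
text_vec = rng.normal(size=384).astype(np.float32)    # Stand-in for a MiniLM vector
image_vec = 30 * rng.random(2048).astype(np.float32)  # Stand-in for larger, non-negative ResNet features
v = combine_embeddings(text_vec, image_vec)
print(v.shape)  # (2432,)
```

With this scheme each part's similarity contributes in proportion to its weight instead of its raw magnitude. The same function would also have to be applied to query vectors at search time so that stored and query vectors live in the same space.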

### 4. Connect to Weaviate and Define Schema

In [None]:
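# Connect to the local Weaviate instance.
# Note: this notebook uses the v3 Python client API (weaviate-client < 4: `weaviate.Client`, `client.schema`, `client.batch`).
# The v4 client replaces these with `weaviate.connect_to_local()` and collection objects.
client = weaviate.Client("http://localhost:8080")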

# Define the class schema
class_name = "Product"
class_obj = {
    "class": class_name,
    "vectorizer": "none",  # We are providing our own vectors
    "properties": [
        {"name": "title", "dataType": ["text"]},
        {"name": "brand", "dataType": ["text"]},
        {"name": "description", "dataType": ["text"]},
        {"name": "price", "dataType": ["text"]},
        {"name": "material", "dataType": ["text"]},
        {"name": "image_url", "dataType": ["text"]}
    ]
}

# Clean up previous schema if it exists
if client.schema.exists(class_name):
    client.schema.delete_class(class_name)

# Create the new schema
client.schema.create_class(class_obj)
print("Schema created successfully.")
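Because `"vectorizer": "none"` means Weaviate only stores and indexes the vectors it is given, the metric used to compare them comes from the class's vector index settings. Weaviate's default is cosine distance on an HNSW index; the sketch below simply makes that explicit in the same v3-style class definition, assuming the same property list. `class_obj_explicit` is a hypothetical name; it would be passed to `client.schema.create_class` in place of `class_obj`.

```python
# Same schema as above, with the (default) vector index settings spelled out
class_obj_explicit = {
    "class": "Product",
    "vectorizer": "none",         # Vectors are supplied by the notebook, not generated by Weaviate
    "vectorIndexType": "hnsw",    # Weaviate's default approximate nearest-neighbour index
    "vectorIndexConfig": {
        "distance": "cosine",     # Weaviate's default metric, stated explicitly
    },
    "properties": [
        {"name": name, "dataType": ["text"]}
        for name in ["title", "brand", "description", "price", "material", "image_url"]
    ],
}
```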

### 5. Generate Embeddings and Ingest Data into Weaviate

This is the main processing loop. For each product it embeds a short text summary (title, description, brand) and the first product image, concatenates the two vectors, and adds the object to a Weaviate batch that is uploaded every 100 objects.
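How the concatenation trades off text against image similarity can be worked out blockwise, assuming cosine distance (Weaviate's default). Write $c = [t;\, i]$ for a stored product vector and $c' = [t';\, i']$ for the vector it is compared against. Because dot products and squared norms of a concatenation add up block by block,

$$
\cos(c, c') = \frac{t^\top t' + i^\top i'}{\sqrt{\lVert t\rVert^2 + \lVert i\rVert^2}\;\sqrt{\lVert t'\rVert^2 + \lVert i'\rVert^2}}
$$

so each block's influence grows with its norm: when $\lVert i\rVert \gg \lVert t\rVert$ for both vectors, $\cos(c, c') \approx \cos(i, i')$ and the text barely matters. If each block is instead normalized to unit length and scaled, $c = [\alpha\,\hat t;\ \beta\,\hat i]$ with $\alpha^2 + \beta^2 = 1$, then $\lVert c\rVert = 1$ and

$$
\cos(c, c') = \alpha^2 \cos(t, t') + \beta^2 \cos(i, i')
$$

which makes the text/image balance an explicit choice of $\alpha$ and $\beta$.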

In [None]:
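def field(row, column, default=''):
    """Return row[column] as a string, or `default` if the column is missing or the value is NaN."""
    value = row.get(column, default)
    return default if pd.isna(value) else str(value)

# Configure batching
client.batch.configure(batch_size=100)

with client.batch as batch:
    for _, row in tqdm(df.iterrows(), total=df.shape[0]):
        # ---- Create the text for embedding (NaN fields become '' instead of the string "nan") ----
        text_to_embed = f"Title: {row['title']}. Description: {field(row, 'description')}. Brand: {field(row, 'brand')}."
        
        # ---- Get text embedding ----
        text_vector = get_text_embedding(text_to_embed)
        
        # ---- Get image embedding ----
        # `images` is expected to hold a Python-style list literal, e.g. "['https://...', ...]"
        try:
            image_urls = ast.literal_eval(row['images'])
        except (ValueError, SyntaxError):
            image_urls = []  # Malformed value: treat the product as having no image
        if isinstance(image_urls, str):
            image_urls = [image_urls]  # A single quoted URL rather than a list
        first_image_url = image_urls[0] if image_urls else None
        image_vector = get_image_embedding(first_image_url) if first_image_url else np.zeros(2048, dtype=np.float32)
        
        # ---- Combine embeddings (simple concatenation: 384 text dims + 2048 image dims) ----
        combined_vector = np.concatenate([text_vector, image_vector])

        # ---- Prepare data object for Weaviate ----
        properties = {
            "title": str(row['title']),
            "brand": field(row, 'brand', 'N/A'),
            "description": field(row, 'description', 'No description available.'),
            "price": field(row, 'price', '$0.00'),
            "material": field(row, 'material', 'N/A'),
            "image_url": str(first_image_url).strip() if first_image_url else ''
        }
        
        # Add object to batch; it is sent automatically once 100 objects accumulate
        batch.add_data_object(
            data_object=properties,
            class_name=class_name,
            vector=combined_vector.tolist()  # Add the combined vector
        )

print("Data ingestion complete!")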
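The loop above does not record which individual objects, if any, Weaviate rejected. The v3 `client.batch.configure` accepts a `callback` that is called with the results of each flushed batch; the sketch below assumes each result is a dict that may carry `result -> errors -> error` (the shape used in Weaviate's v3 examples) and collects those messages. `make_error_collector` and `batch_errors` are hypothetical names, and the sample results are made up only to exercise the function; they are not real Weaviate output.

```python
def make_error_collector(sink):
    """Return a batch callback that appends per-object error messages to `sink`.

    Assumes results look like [{"result": {"errors": {"error": [{"message": "..."}]}}}, ...];
    objects without errors are skipped, and a None result list is ignored.
    """
    def callback(results):
        for item in results or []:
            errors = (item.get("result") or {}).get("errors") or {}
            for err in errors.get("error", []):
                sink.append(err.get("message", str(err)))
    return callback


batch_errors = []
check = make_error_collector(batch_errors)

# Made-up sample results, only to exercise the callback
check([
    {"result": {}},
    {"result": {"errors": {"error": [{"message": "example error"}]}}},
])
print(batch_errors)  # ['example error']
```

In the notebook this would be wired in as `client.batch.configure(batch_size=100, callback=make_error_collector(batch_errors))`, after which `batch_errors` can be inspected once ingestion finishes.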