In [None]:
import ast
import json
import os
from io import BytesIO

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import requests
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms

# Reproducibility: seed the global RNGs used later in this notebook
# (the new classifier head's initialisation and DataLoader shuffling draw
# from torch's global RNG). torch.manual_seed also seeds CUDA devices.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

In [None]:
# Load the raw product catalogue. Missing cells become "" so the string
# parsing below never has to deal with NaN.
DATA_PATH = "../backend/app/data_ingestion/sample_data.csv"
df = pd.read_csv(DATA_PATH).fillna("")

# Clean categories and images
def safe_literal_eval(val):
    """Parse a stringified Python list (e.g. ``"['a', 'b']"``) into a list.

    List-valued CSV columns are stored as their text representation. Anything
    that is not a string, does not look like a list literal, fails to parse,
    or parses to something other than a list yields ``[]``, so callers can
    always iterate over the result.

    Args:
        val: raw cell value (normally ``str``; ``""`` after ``fillna``).

    Returns:
        list: the parsed list, or ``[]`` on any failure.
    """
    if not isinstance(val, str):
        return []
    text = val.strip()  # tolerate stray whitespace around the literal
    if not text.startswith('['):
        return []
    try:
        parsed = ast.literal_eval(text)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        # Malformed literal. Catch only parse failures: a bare `except:` would
        # also swallow KeyboardInterrupt / SystemExit.
        return []
    # Inputs such as "[1], [2]" parse to a tuple; only accept real lists.
    return parsed if isinstance(parsed, list) else []

# Turn the stringified list columns into real Python lists.
df['categories_clean'] = df['categories'].apply(safe_literal_eval)
df['images_clean'] = df['images'].apply(safe_literal_eval)

# Representative photo: the first URL in the list, or None when the list is empty.
df['image_url'] = df['images_clean'].apply(lambda urls: urls[0] if urls else None)

# --- Map complex categories to simple ones --- 
def map_category(cats):
    """Collapse a list of raw category strings into one coarse class.

    Categories are scanned in order. Within each category, the keyword rules
    are tried in priority order (Chair, Table/Desk, Storage, Bed) using
    case-insensitive substring matching. The first hit wins.

    Args:
        cats: iterable of category values (non-strings are stringified).

    Returns:
        str: one of 'Chair', 'Table/Desk', 'Storage', 'Bed', or 'Other' if
        nothing matches.
    """
    rules = (
        (('chair',), 'Chair'),
        (('table', 'desk'), 'Table/Desk'),
        (('bookcase', 'tv stand'), 'Storage'),
        (('bed',), 'Bed'),
    )
    for raw in cats:
        text = str(raw).lower()
        for keywords, label in rules:
            if any(keyword in text for keyword in keywords):
                return label
    return 'Other'

df['simple_category'] = df['categories_clean'].apply(map_category)

# Keep only rows that have an image URL and map to one of the known classes.
has_image = df['image_url'].notna()
is_known = df['simple_category'] != 'Other'
df_cv = df[has_image & is_known].copy()
assert len(df_cv) > 0, "No rows with an image URL and a known category; check DATA_PATH / map_category."

# Integer label per class. Sorting makes the class -> index mapping
# deterministic (independent of row order in the CSV). The trained model's
# output indices are only meaningful together with this mapping.
labels = sorted(df_cv['simple_category'].unique())
label_map = {label: i for i, label in enumerate(labels)}
label_map_inv = {i: label for label, i in label_map.items()}
df_cv['label'] = df_cv['simple_category'].map(label_map)

N_CLASSES = len(labels)
print(f"Found {len(df_cv)} usable images across {N_CLASSES} classes.")
print(f"Classes: {label_map}")
df_cv[['title', 'image_url', 'simple_category', 'label']].head()

In [None]:
# Preprocessing for the ImageNet-pretrained ResNet18: resize (shorter side to
# 256), centre-crop 224x224, convert to a tensor, then normalise each channel
# with the ImageNet mean / standard deviation.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

data_transforms = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]
)

class FurnitureDataset(Dataset):
    """Map-style dataset that downloads product images over HTTP on demand.

    Each item is ``(image, label_tensor)``. The image comes from the row's
    ``image_url`` and is passed through ``transform`` when one is given. If the
    image cannot be downloaded or decoded (or ``transform`` raises), the item
    is a ``(3, 224, 224)`` zero tensor with label ``-1``. A collate function
    can then drop it.

    Args:
        dataframe: frame with at least ``image_url`` and integer ``label`` columns.
        transform: optional callable applied to the RGB PIL image.
        cache_images: if True (default), keep the raw bytes of each successful
            download in memory so later epochs do not fetch the same URL again.
            Failed downloads are not cached and are retried on the next access.
            Note: with ``DataLoader(num_workers>0)`` each worker holds its own cache.
    """

    REQUEST_HEADERS = {'User-Agent': 'Mozilla/5.0'}
    REQUEST_TIMEOUT = 5  # seconds

    def __init__(self, dataframe, transform=None, cache_images=True):
        self.df = dataframe
        self.transform = transform
        self.cache_images = cache_images
        self._bytes_cache = {}  # image_url -> raw bytes of a successful download

    def __len__(self):
        return len(self.df)

    def _fetch_bytes(self, image_url):
        """Return the raw bytes at ``image_url``, served from the cache when enabled."""
        if self.cache_images and image_url in self._bytes_cache:
            return self._bytes_cache[image_url]
        response = requests.get(image_url, timeout=self.REQUEST_TIMEOUT, headers=self.REQUEST_HEADERS)
        response.raise_for_status()  # treat HTTP error statuses as failures
        content = response.content
        if self.cache_images:
            self._bytes_cache[image_url] = content
        return content

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        image_url = row['image_url']
        label = torch.tensor(row['label'], dtype=torch.long)

        try:
            image = Image.open(BytesIO(self._fetch_bytes(image_url))).convert('RGB')
            if self.transform:
                image = self.transform(image)
            return image, label
        except Exception:
            # Deliberate best-effort: remote images are unreliable. Return a
            # blank tensor with the -1 sentinel label so the sample can be
            # filtered out at collation time instead of aborting the epoch.
            return torch.zeros((3, 224, 224)), torch.tensor(-1, dtype=torch.long)

# Hold out 40% of the labelled rows for evaluation. Stratify by label so each
# class appears in both splits whenever the data allows it.
TEST_SIZE = 0.4
try:
    train_df, test_df = train_test_split(
        df_cv, test_size=TEST_SIZE, random_state=42, stratify=df_cv['label']
    )
except ValueError as err:
    # Stratification fails when a class has fewer than 2 rows, or when the
    # test split is smaller than the number of classes. Both are easy to hit on
    # small demo data, so fall back to a plain random split rather than aborting.
    print(f"Stratified split not possible ({err}); using an unstratified split.")
    train_df, test_df = train_test_split(df_cv, test_size=TEST_SIZE, random_state=42)

train_dataset = FurnitureDataset(train_df, transform=data_transforms)
test_dataset = FurnitureDataset(test_df, transform=data_transforms)

# Custom collate_fn to filter out failed image loads
def collate_fn(batch):
    """Collate a batch, dropping samples whose label is the ``-1`` failure sentinel.

    ``FurnitureDataset.__getitem__`` returns label ``-1`` when an image could
    not be loaded. Those samples are removed before delegating to PyTorch's
    default collation.

    Args:
        batch: list of ``(image_tensor, label_tensor)`` pairs.

    Returns:
        ``(images, labels)`` stacked tensors. If every sample failed, returns
        two empty tensors. The labels tensor is ``torch.long`` to match the
        normal label dtype. Callers detect this case via ``images.nelement() == 0``.
    """
    kept = [(img, lbl) for img, lbl in batch if int(lbl) != -1]
    if not kept:
        return torch.empty(0), torch.empty(0, dtype=torch.long)
    return torch.utils.data.dataloader.default_collate(kept)

# DataLoaders. Batches are tiny because this is a demo on a small sample.
# collate_fn drops samples whose image failed to load, so a batch can end up
# smaller than BATCH_SIZE, or even empty.
BATCH_SIZE = 2

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=collate_fn,
)

print(f"Created DataLoaders. Train batches: {len(train_loader)}, Test batches: {len(test_loader)}")

In [None]:
# Transfer learning: start from the pretrained ResNet18 and freeze the entire
# network so that none of the original weights receive gradient updates.
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
model.requires_grad_(False)

# Replace the pretrained classification head with one sized to our classes.
# A freshly constructed nn.Linear has requires_grad=True, so it is the only
# trainable part of the model.
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, N_CLASSES)

# Place the model on the selected device (Module.to moves it in place).
model.to(device)

print(f"Model loaded. Final layer replaced to output {N_CLASSES} classes.")

In [None]:
criterion = nn.CrossEntropyLoss()
# Only the new head has requires_grad=True, so only its parameters are optimised.
optimizer = optim.Adam(model.fc.parameters(), lr=0.001)

NUM_EPOCHS = 3  # Keep low for demo

print("Starting training...")

for epoch in range(NUM_EPOCHS):
    # The backbone is frozen, so keep it in eval mode. In train mode its
    # BatchNorm layers would (a) normalise with per-batch statistics from these
    # very small batches and (b) keep overwriting their running mean/variance
    # buffers. Both happen even though no backbone weight gets a gradient. Only
    # the head is put in train mode.
    model.eval()
    model.fc.train()

    running_loss = 0.0
    n_seen = 0  # images actually trained on (failed downloads are dropped by collate_fn)

    for images, batch_labels in train_loader:
        if images.nelement() == 0:
            continue  # every image in this batch failed to load

        images, batch_labels = images.to(device), batch_labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, batch_labels)
        loss.backward()
        optimizer.step()

        # CrossEntropyLoss averages over the batch by default. Re-weight by
        # batch size so the epoch figure is a per-image mean.
        running_loss += loss.item() * images.size(0)
        n_seen += images.size(0)

    # Average over the images that were actually used. len(train_dataset)
    # equals len(train_dataset.df), so dividing by either one would understate
    # the loss whenever downloads fail.
    if n_seen:
        epoch_loss = running_loss / n_seen
        print(f"Epoch {epoch+1}/{NUM_EPOCHS}, Loss: {epoch_loss:.4f}")
    else:
        print(f"Epoch {epoch+1}/{NUM_EPOCHS}: no images could be loaded; skipped.")

print("Training complete.")

In [None]:
model.eval()  # BatchNorm uses its stored running statistics in eval mode
all_preds = []
all_labels = []

print("Running evaluation...")
with torch.no_grad():  # no gradients needed for inference
    # `batch_labels` (not `labels`) so the class-name list from the data cell
    # is not overwritten by a tensor.
    for images, batch_labels in test_loader:
        if images.nelement() == 0:
            continue  # every image in this batch failed to load

        outputs = model(images.to(device))
        preds = outputs.argmax(dim=1)  # predicted class = index of the largest logit

        # Store plain Python ints so the lists concatenate and compare cleanly later.
        all_preds.extend(preds.cpu().tolist())
        all_labels.extend(batch_labels.tolist())

print("Evaluation complete.")

In [None]:
# 1. Classification Report (Precision, Recall, F1-Score)
if len(all_labels) > 0:
    # Report every known class in label-index order. Passing `labels=`
    # explicitly guarantees `target_names` lines up with the rows sklearn
    # produces. Classes absent from the test split appear with zero support
    # instead of silently disappearing.
    report_labels = sorted(label_map_inv)
    target_names = [label_map_inv[i] for i in report_labels]

    print("Classification Report:")
    print(classification_report(all_labels, all_preds, labels=report_labels,
                                target_names=target_names, zero_division=0))
else:
    print("No valid test data to evaluate.")

print("\nEvaluation: This report shows the model's performance on the test set.")
print("(Note: Metrics will be poor due to the tiny dataset, but this demonstrates the process)")

In [None]:
# 2. Confusion Matrix
if len(all_labels) > 0:
    # Fix the axes to every known class in label-index order. This keeps the
    # matrix shape and tick labels stable even when a class is missing from
    # the test split. Rows are true labels, columns are predictions.
    cm_labels = sorted(label_map_inv)
    target_names = [label_map_inv[i] for i in cm_labels]

    cm = confusion_matrix(all_labels, all_preds, labels=cm_labels)
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=target_names, yticklabels=target_names, ax=ax)
    ax.set(title='Confusion Matrix', xlabel='Predicted Label', ylabel='True Label')
    plt.show()
else:
    print("Cannot show confusion matrix: no test data.")

print("\nEvaluation: This matrix shows which classes the model is confusing (e.g., predicting 'Table/Desk' when it was a 'Chair').")

In [None]:
# Persist the trained weights together with the class-index -> name mapping.
# The state_dict alone cannot be decoded later: output logit i only has a
# meaning through label_map_inv.
os.makedirs("./artifacts", exist_ok=True)
MODEL_SAVE_PATH = "./artifacts/cv_category_model.pth"
LABELS_SAVE_PATH = "./artifacts/cv_category_labels.json"

torch.save(model.state_dict(), MODEL_SAVE_PATH)
print(f"Model state saved to {MODEL_SAVE_PATH}")

# Position i in this list is the class name for output logit i.
class_names = [label_map_inv[i] for i in range(N_CLASSES)]
with open(LABELS_SAVE_PATH, "w", encoding="utf-8") as fh:
    json.dump(class_names, fh, indent=2)
print(f"Class names saved to {LABELS_SAVE_PATH}")