# Food Recognition and Weight Estimation Model

This notebook demonstrates how to use the food recognition and weight estimation model. The model uses a multi-task learning approach with a ResNet50 backbone to simultaneously perform food classification and weight estimation.

## 1. Setup

First, let's install the required packages:

In [None]:
!pip install torch torchvision pandas numpy pillow scikit-learn matplotlib tqdm

## 2. Import the Model

The model has been combined into a single file `kaggle_model.py`. Let's import it:

In [None]:
from kaggle_model import MultiTaskNet, FoodDataset, prepare_data, parse_args, train_model, get_device
import torch
import matplotlib.pyplot as plt
import os
import numpy as np
from PIL import Image
import torchvision.transforms as transforms

## 3. Inference Function

Let's define a function for inference that can be used with sample images:

In [None]:
def process_image(image_path, model, idx_to_label, device):
    """Run the model on one image and return its predictions.

    Args:
        image_path: Path of the image file to open.
        model: Callable that takes a batched input tensor and returns
            ``(class_logits, weight_pred)``.
        idx_to_label: Mapping from class index to food label.
        device: Device the input tensor is moved to before the forward pass.

    Returns:
        A ``(results, original_image)`` tuple. ``results`` is a dict with
        ``"food_predictions"`` (up to three ``{"label", "probability"}``
        dicts, most probable first), ``"predicted_weight"`` (float) and
        ``"image_path"``. ``original_image`` is the RGB PIL image before
        resizing/normalisation.
    """
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Image.open is lazy and keeps the file open; convert() loads the pixels
    # into a new image, so the file can be closed as soon as it returns.
    with Image.open(image_path) as source:
        image = source.convert('RGB')
    original_image = image.copy()

    # Prepare image for model (add the batch dimension)
    input_tensor = transform(image).unsqueeze(0).to(device)

    # Get predictions
    with torch.no_grad():
        class_logits, weight_pred = model(input_tensor)

    # Process classification results. topk raises if k exceeds the number of
    # classes, so cap it for models with fewer than three classes.
    probabilities = torch.nn.functional.softmax(class_logits, dim=1)
    k = min(3, probabilities.shape[1])
    top_prob, top_class = torch.topk(probabilities, k=k, dim=1)

    # Index the single batch row instead of squeeze(): squeeze() would also
    # drop the class dimension when k == 1 and make the result non-iterable.
    top_classes = top_class[0].cpu().tolist()
    top_probs = top_prob[0].cpu().tolist()

    # Get the food labels and probabilities
    food_predictions = [
        {
            "label": idx_to_label[idx],
            "probability": float(prob)
        }
        for idx, prob in zip(top_classes, top_probs)
    ]

    # Get weight prediction
    predicted_weight = float(weight_pred.item())

    results = {
        "food_predictions": food_predictions,
        "predicted_weight": predicted_weight,
        "image_path": image_path
    }

    return results, original_image

def visualize_results(image, results):
    """Display the image with the top prediction and weight drawn on it.

    Args:
        image: Image to show (anything ``plt.imshow`` accepts).
        results: Dict with a ``'food_predictions'`` list of
            ``{'label', 'probability'}`` dicts and a ``'predicted_weight'``
            value; only the first prediction is shown.
    """
    plt.figure(figsize=(10, 8))
    plt.imshow(image)
    plt.axis('off')

    # Caption: best class with its probability, then the estimated weight
    best = results['food_predictions'][0]
    caption_lines = [
        f"Food: {best['label']} ({best['probability']:.2f})",
        f"Weight: {results['predicted_weight']:.1f}g",
    ]
    caption = "\n".join(caption_lines)

    # Dark translucent box keeps the white text legible on any photo
    box_style = {"facecolor": 'black', "alpha": 0.7}
    plt.text(10, 30, caption, color='white', fontsize=12, bbox=box_style)

    plt.tight_layout()
    plt.show()

## 4. Loading the Model

Now let's load a pretrained model if one is available. If loading fails, the cell prints the reason; you can then train a new model in Section 6:

In [None]:
# Set device
device = get_device()
print(f"Using device: {device}")

# Define path to the model (upload your pretrained model to Kaggle or use a new one)
model_path = "../input/food-model/best_model.pth"  # Adjust this path based on your Kaggle dataset

try:
    # Try to load the model; map_location places the tensors on `device`
    checkpoint = torch.load(model_path, map_location=device)

    # Fail early with a clear message if the checkpoint lacks what we need.
    # Defaulting to an empty mapping would build a zero-class model and only
    # fail later with a confusing error.
    missing_keys = [
        key for key in ('label_to_idx', 'model_state_dict')
        if key not in checkpoint or not checkpoint[key]
    ]
    if missing_keys:
        raise ValueError(f"checkpoint is missing: {', '.join(missing_keys)}")

    label_to_idx = checkpoint['label_to_idx']
    # Inverse mapping, used at inference time to turn class indices into labels
    idx_to_label = {v: k for k, v in label_to_idx.items()}

    num_classes = len(label_to_idx)
    model = MultiTaskNet(num_classes)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()  # inference mode

    print(f"Model loaded successfully with {num_classes} food classes")
except Exception as e:
    # Best-effort cell: report the problem instead of stopping the notebook
    print(f"Could not load model: {e}")
    print("You'll need to train a model or upload a pretrained one")

## 5. Inference with Sample Images

Let's use our model for inference on some sample images:

In [None]:
# Image to run through the model - replace with your own image path
sample_image_path = "../input/food-images/pasta.jpg"  # Adjust this path

try:
    results, image = process_image(
        image_path=sample_image_path,
        model=model,
        idx_to_label=idx_to_label,
        device=device,
    )
    visualize_results(image=image, results=results)

    # Text summary of the predictions alongside the figure
    print("\nTop 3 predictions:")
    for pred in results["food_predictions"]:
        print(f"- {pred['label']}: {pred['probability']:.4f}")
    print(f"Estimated weight: {results['predicted_weight']:.1f}g")
except Exception as e:
    # Covers a missing image as well as a model that was never loaded/trained
    print(f"Error processing image: {e}")

## 6. Training a New Model

If you have your own dataset, you can train a new model:

In [None]:
# Define paths for your Kaggle dataset
csv_path = "../input/food-dataset/food_labels.csv"  # Adjust based on your data
images_dir = "../input/food-dataset/images"  # Adjust based on your data
model_dir = "./models"
os.makedirs(model_dir, exist_ok=True)

# Define training parameters
batch_size = 16
num_workers = 2
num_epochs = 5  # Reduced for demonstration purposes

# Load and prepare data
try:
    train_dataloader, val_dataloader, label_to_idx = prepare_data(
        csv_path, images_dir, batch_size=batch_size, num_workers=num_workers
    )

    # Initialize model
    num_classes = len(label_to_idx)
    model = MultiTaskNet(num_classes)
    model.to(device)

    # Rebuild the index -> label mapping together with the model, so the
    # inference cell never pairs the new model with a stale (or undefined)
    # mapping from a previously loaded checkpoint.
    idx_to_label = {v: k for k, v in label_to_idx.items()}

    # Define args object for the training function
    # NOTE(review): `args` is never passed to train_model below, so this
    # learning rate has no visible effect here - confirm against
    # kaggle_model.train_model whether it should be passed in.
    class Args:
        def __init__(self):
            self.lr = 1e-4
    args = Args()

    # Train the model
    training_logs = train_model(
        model,
        train_dataloader,
        val_dataloader,
        device,
        num_epochs,
        model_dir
    )

    # Switch to inference mode so the model can be used directly by the
    # inference cell, whatever mode training left it in.
    model.eval()

    print("Training completed!")
except Exception as e:
    # Best-effort cell: report the problem instead of stopping the notebook
    print(f"Error during training: {e}")

## 7. Visualize Training Results

If you've trained a model, let's visualize the training metrics:

In [None]:
try:
    # Plot training and validation loss
    plt.figure(figsize=(12, 4))

    plt.subplot(1, 2, 1)
    plt.plot(training_logs['epochs'], training_logs['train_loss'], label='Train Loss')
    plt.plot(training_logs['epochs'], training_logs['val_loss'], label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss')
    plt.legend()

    # Plot validation metrics
    plt.subplot(1, 2, 2)
    plt.plot(training_logs['epochs'], training_logs['val_accuracy'], label='Accuracy')
    plt.plot(training_logs['epochs'], training_logs['weight_mae'], label='Weight MAE (g)')
    plt.xlabel('Epoch')
    plt.ylabel('Metric Value')
    plt.title('Validation Metrics')
    plt.legend()

    plt.tight_layout()
    plt.show()
except NameError:
    # training_logs is only defined once the training cell has succeeded
    print("No training logs available to visualize")
except Exception as e:
    # Logs exist but could not be plotted (e.g. an unexpected structure).
    # Report the real cause, and let KeyboardInterrupt propagate, which a
    # bare `except:` would not.
    print(f"Could not visualize training logs: {e}")