In [1]:
%pip install timm torch torchvision gradio datasets pillow numpy

Note: you may need to restart the kernel to use updated packages.


In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, models
from torch.utils.data import DataLoader, Dataset
from datasets import load_from_disk
from PIL import Image
import os
import io
import numpy as np

# Set the paths for your local data and model
DATA_DIR = "food _classifier/DATA_DIR"
MODEL_PATH = "food _classifier/food_classifier_cnn_lstm.pth"
NEW_MODEL_PATH = "food_classifier_cnn_lstm.pth"

# Load the datasets
train_dataset = load_from_disk(os.path.join(DATA_DIR, "train"))
val_dataset = load_from_disk(os.path.join(DATA_DIR, "validation"))

# Custom Dataset class
class FoodDataset(Dataset):
    def __init__(self, hf_dataset, transform=None):
        self.hf_dataset = hf_dataset
        self.transform = transform

    def __len__(self):
        return len(self.hf_dataset)

    def __getitem__(self, idx):
        item = self.hf_dataset[idx]
        image = item['image']
    
    # If the image is not already in RGB, convert it
        if image.mode != 'RGB':
            image = image.convert('RGB')
    
        label = item['label']
    
        if self.transform:
            image = self.transform(image)
    
        return image, label

# Define the model
class FoodClassifierCNNLSTM(nn.Module):
    def __init__(self, num_classes, hidden_size=256, num_layers=2):
        super(FoodClassifierCNNLSTM, self).__init__()
        self.resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
        self.resnet = nn.Sequential(*list(self.resnet.children())[:-2])
        
        self.lstm = nn.LSTM(2048, hidden_size, num_layers, batch_first=True)
        self.layer_norm = nn.LayerNorm(hidden_size)  # Added LayerNorm
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        batch_size = x.size(0)
        cnn_features = self.resnet(x)  # Shape: (batch_size, 2048, H, W)
        cnn_features = cnn_features.view(batch_size, 2048, -1).permute(0, 2, 1)  # Shape: (batch_size, seq_len, 2048)
        lstm_out, _ = self.lstm(cnn_features)  # Shape: (batch_size, seq_len, hidden_size)
        lstm_out = self.layer_norm(lstm_out)   # Apply LayerNorm
        lstm_out = lstm_out[:, -1, :]         # Use the last output of LSTM
        output = self.fc(lstm_out)            # Final classification layer
        return output

# Temporary model for loading pre-trained weights
class TempFoodClassifier(nn.Module):
    def __init__(self, num_classes):
        super(TempFoodClassifier, self).__init__()
        self.resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
        num_ftrs = self.resnet.fc.in_features
        self.resnet.fc = nn.Linear(num_ftrs, num_classes)

    def forward(self, x):
        return self.resnet(x)

# Set up data transformations
data_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Create datasets
train_dataset = FoodDataset(train_dataset, transform=data_transforms)
val_dataset = FoodDataset(val_dataset, transform=data_transforms)

# Create DataLoaders
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=32)

# Initialize the model
num_classes = len(set(train_dataset.hf_dataset['label']))
model = FoodClassifierCNNLSTM(num_classes)

# Determine the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load the previously trained weights
pretrained_weights = torch.load(MODEL_PATH, map_location=device)

# Create a temporary model with the same architecture as the one used for training
temp_model = TempFoodClassifier(num_classes)
temp_model.load_state_dict(pretrained_weights, strict=False)


# Copy weights from the temporary model to the ResNet part of our new model
model.resnet.load_state_dict(temp_model.resnet.state_dict(), strict=False)

# Move the model to the appropriate device
model.to(device)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)  # Lower learning rate for fine-tuning

# Training loop
num_epochs = 20

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for inputs, labels in train_dataloader:
        inputs, labels = inputs.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item() * inputs.size(0)
    
    epoch_loss = running_loss / len(train_dataloader.dataset)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")
    
    # Validation
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in val_dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    accuracy = 100 * correct / total
    print(f"Validation Accuracy: {accuracy:.2f}%")

print("Training complete!")

# Save the new model
torch.save(model.state_dict(), NEW_MODEL_PATH)

# Function to predict
def predict(model, image_path):
    model.eval()
    image = Image.open(image_path).convert('RGB')
    image_tensor = data_transforms(image).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(image_tensor)
    _, predicted = torch.max(output, 1)
    return train_dataset.hf_dataset.features['label'].int2str(predicted.item())

# Example usage:
# result = predict(model, '/path/to/test/image.jpg')
# print(f"Predicted food category: {result}")

Using device: cuda


  pretrained_weights = torch.load(MODEL_PATH, map_location=device)


Epoch 1/20, Loss: 1.6183
Validation Accuracy: 57.45%
Epoch 2/20, Loss: 0.2503
Validation Accuracy: 62.91%
Epoch 3/20, Loss: 0.0447
Validation Accuracy: 65.09%
Epoch 4/20, Loss: 0.0215
Validation Accuracy: 64.73%
Epoch 5/20, Loss: 0.0140
Validation Accuracy: 66.91%
Epoch 6/20, Loss: 0.0094
Validation Accuracy: 65.45%
Epoch 7/20, Loss: 0.0071
Validation Accuracy: 66.55%
Epoch 8/20, Loss: 0.0055
Validation Accuracy: 66.91%
Epoch 9/20, Loss: 0.0046
Validation Accuracy: 65.82%
Epoch 10/20, Loss: 0.0040
Validation Accuracy: 67.64%
Epoch 11/20, Loss: 0.0033
Validation Accuracy: 68.00%
Epoch 12/20, Loss: 0.0029
Validation Accuracy: 67.64%
Epoch 13/20, Loss: 0.0033
Validation Accuracy: 68.00%
Epoch 14/20, Loss: 0.0025
Validation Accuracy: 66.55%
Epoch 15/20, Loss: 0.0023
Validation Accuracy: 66.91%
Epoch 16/20, Loss: 0.0019
Validation Accuracy: 67.64%
Epoch 17/20, Loss: 0.0017
Validation Accuracy: 66.55%
Epoch 18/20, Loss: 0.0015
Validation Accuracy: 66.18%
Epoch 19/20, Loss: 0.0014
Validation 

In [7]:
import torch
import torchvision.transforms as transforms
from torchvision import models
import torch.nn as nn
from PIL import Image
import gradio as gr
from datasets import load_from_disk
import os
import traceback
import numpy as np

class FoodClassifier(nn.Module):
    def __init__(self, num_classes):
        super(FoodClassifier, self).__init__()
        self.resnet = models.resnet50(weights=None)
        num_ftrs = self.resnet.fc.in_features
        self.resnet.fc = nn.Linear(num_ftrs, num_classes)

    def forward(self, x):
        return self.resnet(x)

def load_model(model_path, num_classes):
    model = FoodClassifier(num_classes)
    if os.path.exists(model_path):
        try:
            state_dict = torch.load(model_path, map_location=torch.device('cuda'))
            # Load state dict with strict=False to ignore non-matching keys
            model.load_state_dict(state_dict, strict=True)
            
        except Exception as e:
            print(f"Error loading model: {e}")
            print("Initializing with random weights.")
    else:
        print(f"Model file not found at {model_path}. Initializing with random weights.")
    model.eval()
    return model



data_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

model_path = 'food _classifier/food_classifier (1).pth'
dataset_path = 'food _classifier/DATA_DIR'

try:
    ds = load_from_disk(dataset_path)
    categories = ds['train'].features['label'].names
    num_classes = len(categories)
except Exception as e:
    print(f"Error loading dataset: {e}")
    print("Using default categories...")
    categories = [f"Category_{i}" for i in range(11)]
    num_classes = len(categories)

model = load_model(model_path, num_classes)

def predict_image(image):
    try:
        if image is None:
            return {"Error": "No image provided"}

        if not isinstance(image, np.ndarray):
            image = np.array(image)

        img = Image.fromarray(image.astype('uint8'), 'RGB')
        img_tensor = data_transforms(img).unsqueeze(0)

        with torch.no_grad():
            output = model(img_tensor)

        probabilities = torch.nn.functional.softmax(output[0], dim=0)
        top5_prob, top5_catid = torch.topk(probabilities, min(5, len(categories)))

        results = {categories[top5_catid[i].item()]: float(top5_prob[i]) for i in range(top5_prob.size(0))}

        return results
    except Exception as e:
        print(f"Error in predict_image: {e}")
        print(traceback.format_exc())
        return {"Error": "An error occurred"}

iface = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(),
    outputs=gr.Label(num_top_classes=5),
    title="Food Classification",
    description="Upload an image of food to classify it into one of the available categories.",
    examples=[
        "https://images.pexels.com/photos/39803/pexels-photo-39803.jpeg?cs=srgb&dl=apple-fruit-healthy-food-39803.jpg&fm=jpg",
        "https://imgs.search.brave.com/G1O6-4aTqLI2WJ2H5iQjo2tGmFG4HxmL0g7LpBsX6xo/rs:fit:860:0:0:0/g:ce/aHR0cHM6Ly93d3cu/cmVjaXBldGluZWF0/cy5jb20vdGFjaHlv/bi8yMDE5LzA5L0Nv/b2stcmljZS1vbi1z/dG92ZV82LmpwZw",
        "https://tse1.mm.bing.net/th?id=OIP.SvmpgPsIHYoHS08RfnQx1wHaJ6&pid=Api&P=0&h=180",
        "https://imgs.search.brave.com/7zLAVx637-Ij9d_zCXh54wdd8sGovacvW-633wgdt24/rs:fit:860:0:0:0/g:ce/aHR0cHM6Ly9tZWRp/YS5nZXR0eWltYWdl/cy5jb20vaWQvMTM5/MzQwNDg5My9waG90/by90ZXJpeWFraS1z/aHJpbXAtd2l0aC1y/YW1lbi1ub29kbGVz/LmpwZz9zPTYxMng2/MTImdz0wJms9MjAm/Yz05TjNITjdSUlVi/RzVOQlZMd3lJM2ZV/SEJ5WERRVkk2RzNh/QkRVQlFCLVI4PQ",
        "https://imgs.search.brave.com/rWUYUHBxlaQddySHlxGKEoLWBxsp6srd2Y4aCSpN55Q/rs:fit:860:0:0:0/g:ce/aHR0cHM6Ly9tZWRp/YS5pc3RvY2twaG90/by5jb20vaWQvNTMx/NDY0MzY2L3Bob3Rv/L2JlZWYtc3RlYWtz/LW9uLXRoZS1ncmls/bC5qcGc_cz02MTJ4/NjEyJnc9MCZrPTIw/JmM9Z1A1VmlHbkow/OFlUelh0aFNPTUt6/WkVNcXRBYmNNMmpo/RWFvMDlXMWtBUT0"
        

    ]
)

iface.launch(share=True, debug=True)


  state_dict = torch.load(model_path, map_location=torch.device('cuda'))


Error loading model: Error(s) in loading state_dict for FoodClassifier:
	Missing key(s) in state_dict: "resnet.conv1.weight", "resnet.bn1.weight", "resnet.bn1.bias", "resnet.bn1.running_mean", "resnet.bn1.running_var", "resnet.layer1.0.conv1.weight", "resnet.layer1.0.bn1.weight", "resnet.layer1.0.bn1.bias", "resnet.layer1.0.bn1.running_mean", "resnet.layer1.0.bn1.running_var", "resnet.layer1.0.conv2.weight", "resnet.layer1.0.bn2.weight", "resnet.layer1.0.bn2.bias", "resnet.layer1.0.bn2.running_mean", "resnet.layer1.0.bn2.running_var", "resnet.layer1.0.conv3.weight", "resnet.layer1.0.bn3.weight", "resnet.layer1.0.bn3.bias", "resnet.layer1.0.bn3.running_mean", "resnet.layer1.0.bn3.running_var", "resnet.layer1.0.downsample.0.weight", "resnet.layer1.0.downsample.1.weight", "resnet.layer1.0.downsample.1.bias", "resnet.layer1.0.downsample.1.running_mean", "resnet.layer1.0.downsample.1.running_var", "resnet.layer1.1.conv1.weight", "resnet.layer1.1.bn1.weight", "resnet.layer1.1.bn1.bias", "res

Keyboard interruption in main thread... closing server.
Killing tunnel 127.0.0.1:7860 <> https://dfcf07b8da1455aa02.gradio.live




In [31]:
import torch
import torchvision.transforms as transforms
from torchvision import models
import torch.nn as nn
from PIL import Image
import gradio as gr
from datasets import load_from_disk
import os
import traceback
import numpy as np

class FoodClassifier(nn.Module):
    def __init__(self, num_classes, lstm_hidden_size=256, lstm_num_layers=2):
        super(FoodClassifier, self).__init__()
        self.resnet = models.resnet50(weights=None)
        self.resnet = nn.Sequential(*list(self.resnet.children())[:-1])  # Remove the last FC layer
        
        self.lstm = nn.LSTM(2048, lstm_hidden_size, lstm_num_layers, batch_first=True)
        self.fc = nn.Linear(lstm_hidden_size, num_classes)

    def forward(self, x):
        batch_size, _, _, _ = x.size()
        x = self.resnet(x)
        x = x.view(batch_size, 1, -1)  # Reshape for LSTM input
        lstm_out, _ = self.lstm(x)
        x = self.fc(lstm_out[:, -1, :])  # Use the last LSTM output
        return x

def load_model(model_path, num_classes):
    model = FoodClassifier(num_classes)
    if os.path.exists(model_path):
        try:
            state_dict = torch.load(model_path, map_location=torch.device('cpu'), weights_only=True)
            model.load_state_dict(state_dict, strict=False)
            print("Model loaded successfully.")
        except Exception as e:
            print(f"Error loading model: {e}")
            print("Initializing with random weights.")
    else:
        print(f"Model file not found at {model_path}. Initializing with random weights.")
    model.eval()
    return model

data_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

model_path = 'food _classifier/food_classifier_cnn_lstm.pth'
dataset_path = 'food _classifier/DATA_DIR'

try:
    ds = load_from_disk(dataset_path)
    categories = ds['train'].features['label'].names
    num_classes = len(categories)
except Exception as e:
    print(f"Error loading dataset: {e}")
    print("Using default categories...")
    categories = [f"Category_{i}" for i in range(11)]
    num_classes = len(categories)

model = load_model(model_path, num_classes)

def predict_image(image):
    try:
        if image is None:
            return {"Error": "No image provided"}
        if not isinstance(image, np.ndarray):
            image = np.array(image)
        img = Image.fromarray(image.astype('uint8'), 'RGB')
        img_tensor = data_transforms(img).unsqueeze(0)
        with torch.no_grad():
            output = model(img_tensor)
        probabilities = torch.nn.functional.softmax(output[0], dim=0)
        top5_prob, top5_catid = torch.topk(probabilities, min(5, len(categories)))
        results = {categories[top5_catid[i].item()]: float(top5_prob[i]) for i in range(top5_prob.size(0))}
        return results
    except Exception as e:
        print(f"Error in predict_image: {e}")
        print(traceback.format_exc())
        return {"Error": "An error occurred"}

iface = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(),
    outputs=gr.Label(num_top_classes=5),
    title="Food Classification",
    description="Upload an image of food to classify it into one of the available categories.",
    examples=[
        "https://images.pexels.com/photos/39803/pexels-photo-39803.jpeg?cs=srgb&dl=apple-fruit-healthy-food-39803.jpg&fm=jpg",
        "https://tse1.mm.bing.net/th?id=OIP.SvmpgPsIHYoHS08RfnQx1wHaJ6&pid=Api&P=0&h=180",
    ]
)

iface.launch(share=True, debug=True)

Model loaded successfully.
* Running on local URL:  http://127.0.0.1:7861
* Running on public URL: https://c30f2937f9e110cce5.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


Keyboard interruption in main thread... closing server.
Killing tunnel 127.0.0.1:7861 <> https://c30f2937f9e110cce5.gradio.live


