In [114]:
import os

import cv2
import pandas as pd
import torch
import torch.nn as nn
from torch import Tensor
from torch.utils.data import DataLoader, Dataset, random_split
from sklearn.preprocessing import MultiLabelBinarizer
import torchvision

import matplotlib.pyplot as plt

import numpy as np

In [115]:
DATASET_DIR = "../../data/nutrition5k_dataset_nosides/"
PROCESSED_DIR = os.path.join(DATASET_DIR, "processed")

IMG_DIR = "../../data/nutrition5k_dataset_nosides/imagery/realsense_overhead/"

INGREDIENTS_PATH = os.path.join(PROCESSED_DIR, "ingredients_metadata.csv")
DISHES_PATH = os.path.join(PROCESSED_DIR, "dishes_info.csv")

In [116]:
df = pd.read_csv(INGREDIENTS_PATH)
labels = df["ingredient_id"]

label_binarizer = MultiLabelBinarizer()
label_binarizer.fit([labels.to_list()])

num_of_classes = label_binarizer.classes_.shape[0]
print("number of classes:", num_of_classes)
# label_binarizer.transform([[2, 3]])[0]

number of classes: 247


In [117]:
class IngredientDataset(Dataset):
    def __init__(self, img_dir: str, ingredients_path: str, dish_info_path: str, transform=None):
        self.img_dir = img_dir

        self.ing_df = pd.read_csv(ingredients_path)
        self.dish_info_df = pd.read_csv(dish_info_path)

        self.transform = transform

        self.label_binarizer = MultiLabelBinarizer()
        self.label_binarizer.fit([self.ing_df["ingredient_id"].to_list()])

    def __len__(self) -> int:
        return len(self.dish_info_df)

    def __getitem__(self, index):
        dish = self.dish_info_df.iloc[index]
        dish_id = dish[0]
        # print(dish_id)

        ingredient_ids = self.ing_df[self.ing_df["dish_id"] == dish_id]["ingredient_id"].values
        label_encoded = self.label_binarizer.transform([ingredient_ids])[0]
        label_tensor = torch.FloatTensor(label_encoded)
        # print(ingredient_ids)

        dish_weight = dish[2]
        # print(dish_weight)
        weight_in_g_tensor = torch.FloatTensor([dish_weight])

        img_path = os.path.join(self.img_dir, dish_id, "rgb.png")
        img = cv2.imread(img_path, cv2.IMREAD_COLOR)
        img = np.transpose(img, (2, 0, 1))
        # plt.imshow(img)
        # plt.show()

        img_tensor = torch.FloatTensor(img)
        img_tensor = self.transform(img_tensor) if self.transform is not None else img_tensor

        return img_tensor, label_tensor, weight_in_g_tensor

    def get_num_of_classes(self) -> int:
        return self.label_binarizer.classes_.shape[0]

In [118]:
def get_accuracy(y_pred: Tensor, y_train: Tensor) -> float:
    y_pred_index = torch.argmax(y_pred, 1)
    y_train_index = torch.argmax(y_train, 1)
    return (y_pred_index == y_train_index).sum().item() / y_pred.shape[0]

In [119]:
def print_statistics(epoch: int, batch: int, num_batches: int, loss: float, acc: float):
    print(f"EPOCH {epoch + 1} | BATCH {batch + 1} of {num_batches} | LOSS {loss:.4f} | ACCURACY {acc:.4f}")

In [120]:
dataset = IngredientDataset(img_dir=IMG_DIR, ingredients_path=INGREDIENTS_PATH, dish_info_path=DISHES_PATH)

training_dataset, validation_dataset = random_split(dataset, [0.7, 0.3])

training_dataloader = DataLoader(training_dataset, batch_size=32, shuffle=True)
validation_dataloader = DataLoader(validation_dataset, batch_size=32, shuffle=False)

In [129]:
import ssl
ssl._create_default_https_context = ssl._create_unverified_context

num_classes = dataset.get_num_of_classes()
model = torchvision.models.mobilenet_v3_small(pretrained=True)

num_ftrs = model.classifier[3].in_features

# Change the output layer to match the custom number of classes
model.classifier[3] = torch.nn.Linear(num_ftrs, num_classes)

Downloading: "https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth" to /Users/maxburzer/.cache/torch/hub/checkpoints/mobilenet_v3_small-047dcff4.pth
100.0%


In [122]:
loss_function = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)

# collect stats
train_loss = []
train_acc = []
val_acc = []

In [123]:
device = "cpu"

In [124]:
num_epochs = 1

for epoch in range(num_epochs):
    
    model.train()
    torch.enable_grad()
    print("TRAINING...")

    for index, (X_train, y_train, y_w_train) in enumerate(training_dataloader):
        # move to GPU
        X_train = X_train.to(device)
        y_train = y_train.to(device)

        # forward
        y_pred = model(X_train)
        loss = loss_function(y_pred, y_train)
        acc = get_accuracy(y_pred, y_train)

        # collect stats
        train_loss.append(loss.item())
        train_acc.append(acc)
        print_statistics(epoch, index, len(training_dataloader), loss.item(), acc)

        # backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    model.eval()
    torch.no_grad()
    print("TESTING...")

    for index, (X_val, y_val, y_w_train) in enumerate(validation_dataloader):
        # move to GPU
        X_val = X_val.to(device)
        y_val = y_val.to(device)

        # forward
        y_pred = model(X_val)
        acc = get_accuracy(y_pred, y_val)

        # collect stats
        val_acc.append(acc)
        print_statistics(epoch, index, len(validation_dataloader), 0, acc)

TRAINING...


  dish_id = dish[0]
  dish_weight = dish[2]


EPOCH 1 | BATCH 1 of 77 | LOSS 53.9290 | ACCURACY 0.0000
EPOCH 1 | BATCH 2 of 77 | LOSS 34.2052 | ACCURACY 0.0000
EPOCH 1 | BATCH 3 of 77 | LOSS 48.3939 | ACCURACY 0.0000
EPOCH 1 | BATCH 4 of 77 | LOSS 43.4006 | ACCURACY 0.0000
EPOCH 1 | BATCH 5 of 77 | LOSS 26.8262 | ACCURACY 0.0312
EPOCH 1 | BATCH 6 of 77 | LOSS 33.3309 | ACCURACY 0.0000
EPOCH 1 | BATCH 7 of 77 | LOSS 27.1515 | ACCURACY 0.0000


KeyboardInterrupt: 