In [10]:
import os
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset, random_split, DataLoader
from sklearn.preprocessing import LabelEncoder

In [25]:
# Create images paths
print("Creating image paths list...")
dir_path = 'raw_images'

image_paths = []

for dish_dir in os.listdir(dir_path):
    img_dir = os.path.join(dir_path, dish_dir, 'frames_sampled30')

    if not os.path.exists(img_dir):
        continue
    
    for image in os.listdir(img_dir):
        img_path = os.path.join(img_dir, image)
        image_paths.append(img_path)

import random

# random.seed(42)
#image_paths = random.sample(image_paths, min(10000, len(image_paths)))  
print(len(image_paths))

Creating image paths list...
54845


In [26]:
transform = transforms.Compose([
    transforms.ToTensor(),
])

labels = [os.path.basename(os.path.dirname(os.path.dirname(p))) for p in image_paths]
label_encoder = LabelEncoder()
label_encoder.fit(labels)

class NutritionDataset(Dataset):
    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform
        labels = [os.path.basename(os.path.dirname(os.path.dirname(p))) for p in image_paths]
        label_encoder = LabelEncoder()
        label_encoder.fit(labels)
        self.encoder = label_encoder

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert('RGB')
        label_str = os.path.basename(os.path.dirname(os.path.dirname(img_path)))
        label = self.encoder.transform([label_str])[0]

        if self.transform:
            image = self.transform(image)
        else:
            image = transform(image)

        return image, label

In [27]:
input_size = (225, 225)

# mean and std of the entire dataset
data_normals = {
    'mean': [0.5005, 0.4726, 0.3732],
    'std': [0.2193, 0.2296, 0.2398]
}

test_transform = transforms.Compose([
    transforms.Resize(input_size),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=data_normals['mean'],
        std=data_normals['std']
    )
])

In [30]:
import pickle
with open("testimg.pkl", "rb") as f:
    test_imgs = pickle.load(f)

test_set = NutritionDataset(image_paths=test_imgs, transform=test_transform)
test_loader = DataLoader(test_set, batch_size=32, shuffle=False)

In [31]:

import torch
from sklearn.metrics import classification_report, accuracy_score
from torchvision import models
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm


device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
model = models.resnet18(pretrained=False)
model.fc = nn.Linear(model.fc.in_features, 4772)
model.load_state_dict(torch.load('resnettemp.pth', map_location=device))
model.to(device)

apreds = []
alabels = []

with torch.no_grad():
    for images, labels in tqdm(test_loader):
        images = images.to(device)
        outputs = model(images)
        temp, preds = torch.max(outputs, 1)
        apreds.extend(preds.cpu().numpy())
        alabels.extend(labels.numpy())

print(classification_report(alabels, apreds, digits=4))
print(accuracy_score(alabels, apreds))


100%|██████████| 1200/1200 [09:15<00:00,  2.16it/s]

              precision    recall  f1-score   support

           0     0.9231    0.7500    0.8276        16
           1     1.0000    0.8889    0.9412        18
           2     1.0000    1.0000    1.0000        16
           3     1.0000    1.0000    1.0000        16
           4     1.0000    0.8462    0.9167        13
           5     1.0000    0.8571    0.9231        14
           6     0.8889    1.0000    0.9412        16
           7     0.9000    0.7500    0.8182        12
           8     0.8000    0.5333    0.6400        15
           9     1.0000    1.0000    1.0000        14
          10     0.9091    1.0000    0.9524        10
          11     1.0000    0.9286    0.9630        14
          12     1.0000    0.8667    0.9286        15
          13     1.0000    1.0000    1.0000        15
          14     1.0000    1.0000    1.0000        18
          15     1.0000    0.6667    0.8000        12
          16     0.6923    1.0000    0.8182         9
          17     1.0000    


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
