In [4]:
import json
import os
import random
import shutil
import zipfile
from pathlib import Path

import kaggle
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as T

In [2]:
# Kaggle dataset slug in "<owner>/<dataset>" form, passed to the Kaggle API downloader.
path = Path('kmader') / 'food41'

In [3]:
def get_data_transforms(image_size=224):
    """Build the preprocessing pipelines used for training and for evaluation.

    Args:
        image_size: Side length, in pixels, of the square images produced by both
            pipelines. Defaults to 224, the value previously hard-coded here.

    Returns:
        A ``(train_transforms, valid_n_test_transforms)`` tuple of ``T.Compose`` objects.
        The training pipeline applies a random resized crop, a random rotation within
        +/-35 degrees, and random vertical and horizontal flips (each with probability
        0.27), then converts to a tensor and normalizes. The evaluation pipeline only
        resizes to ``image_size x image_size``, converts to a tensor and normalizes.
    """
    # Per-channel mean/std shared by both pipelines so they cannot drift apart.
    # These are the widely used ImageNet statistics — presumably chosen to match a
    # pretrained torchvision backbone (confirm against the model used downstream).
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    train_transforms = T.Compose([
        T.RandomResizedCrop(image_size),
        T.RandomRotation(35),
        T.RandomVerticalFlip(0.27),
        T.RandomHorizontalFlip(0.27),
        T.ToTensor(),
        T.Normalize(channel_mean, channel_std),
    ])

    valid_n_test_transforms = T.Compose([
        T.Resize((image_size, image_size)),
        T.ToTensor(),
        T.Normalize(channel_mean, channel_std),
    ])

    return train_transforms, valid_n_test_transforms

In [5]:
# Download the kmader/food41 archive from Kaggle and unpack it under data/.
# Requires a Kaggle API token: https://www.kaggle.com/docs/api#authentication
# NOTE(review): this repeats the download/extract step performed inside
# download_food101_kaggle(download=True); consider relying on that function instead.
path = Path('kmader/food41')
data_path = Path('data')
kaggle.api.dataset_download_cli(str(path))
data_path.mkdir(parents=True, exist_ok=True)
# The context manager closes the archive's file handle once extraction finishes.
with zipfile.ZipFile('food41.zip') as archive:
    archive.extractall(data_path)

Downloading food41.zip to /root/Nutri-2.0


100%|██████████| 5.30G/5.30G [04:25<00:00, 21.5MB/s]  





In [None]:
def _copy_listed_images(source_dir, class_name, destinations):
    """Copy the images of one class into every destination split that lists them.

    Args:
        source_dir: Directory holding one sub-folder of images per class.
        class_name: Name of the class sub-folder to scan.
        destinations: Mapping of destination root (Path) -> set of image ids of the
            form ``"<class_name>/<file name up to the first '.'>"``. Each matching
            image is copied to ``<root>/<class_name>/<file name>``.
    """
    for root in destinations:
        (root / class_name).mkdir(parents=True, exist_ok=True)
    for image in os.listdir(source_dir / class_name):
        # The meta lists identify images as "<class>/<id>", without the file extension.
        image_id = class_name + '/' + image.split('.')[0]
        for root, ids in destinations.items():
            if image_id in ids:
                shutil.copy(source_dir / class_name / image, root / class_name / image)


def _prepare_food101_folders(data_path, train_size, rng):
    """Download/extract kmader/food41 and lay out train/valid/test class folders under data_path.

    Each class's list from ``meta/meta/train.json`` is shuffled with ``rng``; the first
    ``int(len * train_size)`` ids go to ``train`` and the rest to ``valid``. The lists in
    ``meta/meta/test.json`` populate ``test``. The extracted ``images`` folder is deleted
    afterwards to reclaim disk space.
    """
    # A literal slug avoids str(Path(...)) producing a backslash-separated slug on Windows.
    kaggle.api.dataset_download_cli('kmader/food41')
    data_path.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile('food41.zip') as archive:
        archive.extractall(data_path)
    with open(data_path / 'meta/meta/train.json', 'r') as fp:
        train_dict = json.load(fp)
    with open(data_path / 'meta/meta/test.json', 'r') as fp:
        test_dict = json.load(fp)
    # The archive was extracted into data_path (meta/ is read from there too), so the raw
    # images are under data_path, not under a cwd-relative 'food41/'.
    # NOTE(review): confirm the archive's top-level 'images/' folder name.
    original_data_path = data_path / 'images'
    for key, value in train_dict.items():
        shuffled = list(value)
        rng.shuffle(shuffled)
        n_train = int(len(shuffled) * train_size)
        _copy_listed_images(original_data_path, key, {
            data_path / 'train': set(shuffled[:n_train]),
            data_path / 'valid': set(shuffled[n_train:]),
        })
    for key, value in test_dict.items():
        _copy_listed_images(original_data_path, key, {data_path / 'test': set(value)})
    shutil.rmtree(original_data_path)


def download_food101_kaggle(batch_size=32, download=True, train_size=0.75, seed=None):
    """Download the Food-101 dataset, split it into train/valid/test sets, and build loaders.

    Make sure your Kaggle API token file is in the expected folder; see
    https://www.kaggle.com/docs/api#authentication

    Args:
        batch_size: Batch size for all three DataLoaders.
        download: If True, download and extract ``kmader/food41`` and rebuild the
            ``data/train``, ``data/valid`` and ``data/test`` class folders from the
            official meta lists. If False, reuse the folders already on disk.
        train_size: Fraction (strictly between 0 and 1) of each class's official
            training images kept for training; the remainder becomes validation.
            Only used when ``download`` is True.
        seed: Optional seed for the per-class train/valid shuffle. Pass an int for a
            reproducible split; None draws fresh randomness on every call.

    Returns:
        ``(train_loader, valid_loader, test_loader, test_dataset, class_names)``.
        Only the training loader shuffles.

    Raises:
        ValueError: if ``download`` is True and ``train_size`` is not in (0, 1).
    """
    new_data_path = Path('data')
    if download:
        if not 0 < train_size < 1:
            raise ValueError('train_size must be strictly between 0 and 1')
        _prepare_food101_folders(new_data_path, train_size, random.Random(seed))
    train_dir = new_data_path / 'train'
    valid_dir = new_data_path / 'valid'
    test_dir = new_data_path / 'test'
    train_transforms, valid_n_test_transforms = get_data_transforms()
    train_dataset = datasets.ImageFolder(train_dir, transform=train_transforms)
    valid_dataset = datasets.ImageFolder(valid_dir, transform=valid_n_test_transforms)
    test_dataset = datasets.ImageFolder(test_dir, transform=valid_n_test_transforms)
    class_names = test_dataset.classes
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size)
    return train_loader, valid_loader, test_loader, test_dataset, class_names