# Vietnamese Food Recognizer

## 1. Dataset
- The data is derived from the [Vietnamese Foods dataset on Kaggle](https://www.kaggle.com/datasets/quandang/vietnamese-foods/data).

## Table of contents

### 1. [Crawling data for 3 different dishes](#1)

### 2. [Data preparation](#2)

### 3. [Training model](#3)

### 4. [Transfer learning with MobileNet to recognize more foods](#4-transfer-learning-with-mobilenet-to-recognize-more-foods)

In [26]:
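import ssl

# Disable HTTPS certificate verification for the standard library's HTTPS clients
# (urllib, http.client). This is a local workaround and is unsafe outside a throwaway notebook.
ssl._create_default_https_context = ssl._create_unverified_context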
import numpy as np
import torch
import torch.nn as nn
import requests
from io import BytesIO
import asyncio
import cv2
from PIL import Image
import os
from time import time
from tqdm import tqdm


<a id="1"></a> 

### 1. Crawling data for 3 different dishes 



The dataset covers more than 30 Vietnamese dishes. For this challenge, I use only three of them: **Banh mi, Pho, and Mi Quang**.

The existing images for each dish can be downloaded directly. However, I want to practice writing asynchronous code and compare the performance of synchronous and asynchronous processing, so I use only the raw text-file endpoint for each dish, which lists one image URL per line, and fetch the images myself.
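As a rough illustration of why the asynchronous version should be faster for this kind of I/O-bound work, here is a minimal sketch that replaces real HTTP requests with `asyncio.sleep`. The names `fake_download`, `run_sequentially`, `run_concurrently`, and `timed` are hypothetical and not part of the notebook, and the delays are made up.

```python
import asyncio
import time


async def fake_download(delay: float) -> float:
    """Stand-in for one HTTP request: waits `delay` seconds, then returns it."""
    await asyncio.sleep(delay)
    return delay


async def run_sequentially(delays):
    # Each wait finishes before the next begins, so total time is about the sum.
    return [await fake_download(d) for d in delays]


async def run_concurrently(delays):
    # All waits overlap, so total time is about the longest single delay.
    return await asyncio.gather(*(fake_download(d) for d in delays))


def timed(coro_fn, delays):
    """Run `coro_fn(delays)` on a fresh event loop and measure wall-clock time."""
    start = time.perf_counter()
    results = asyncio.run(coro_fn(delays))
    return results, time.perf_counter() - start


delays = [0.05] * 10
seq_results, seq_time = timed(run_sequentially, delays)
con_results, con_time = timed(run_concurrently, delays)
print(f"sequential: {seq_time:.2f}s, concurrent: {con_time:.2f}s")
```

Inside a notebook an event loop is already running, so `asyncio.run` would raise an error there; in that setting you would `await run_concurrently(delays)` directly in the cell instead.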

In [13]:
import httpx

ROOT_DIR = os.path.abspath(".")
ROOT_DIR

'/Users/mac/Desktop/Code/Personal_Project/VNUK/VN_Food_Recognizer'

In [None]:

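# Synchronous baseline: download every image with one blocking request at a time.
# NOTE: `urls` (dish name -> URL-list endpoint) is defined in the asynchronous cell
# further down; run that definition before this cell.
print("Start setting up data")
t1 = time()
dataset_dir = os.path.join(ROOT_DIR, "dataset")
for food_name, food_data_url in urls.items():
    print(f"Crawling data for: {food_name}")
    response = requests.get(food_data_url, timeout=10)
    # The endpoint returns plain text with one image URL per line; drop blank lines.
    lines = response.content.decode("utf-8").split("\n")
    food_urls = [line.strip() for line in lines if line.strip()]

    # Create dataset/<dish name>/ (and the parent folder) if it does not exist yet.
    food_dir = os.path.join(dataset_dir, food_name)
    os.makedirs(food_dir, exist_ok=True)

    counter = 1
    for food_url in tqdm(food_urls, desc="Image"):
        try:
            food_response = requests.get(food_url, timeout=10)
            img = Image.open(BytesIO(food_response.content))
            img_dir = os.path.join(food_dir, f"image_{counter}.png")
            img.save(img_dir, "PNG")
            counter += 1
        except Exception as e:
            # Skip URLs that time out, fail to resolve, or do not decode as an image.
            print(e)
    print(f"Finish downloading images for {food_name}")
t2 = time()
print(f"DONE! Total time taken for synchronous processing: {t2 - t1}s")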

Start setting up data
Crawling data for: Banh mi


Image:   0%|          | 3/1336 [00:30<5:09:36, 13.94s/it]

HTTPSConnectionPool(host='images.foody.vn', port=443): Max retries exceeded with url: /res/g69/683391/prof/s640x400/foody-mobile-hmkh-jpg-443-636392742277250874.jpg (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x16bd2aad0>, 'Connection to images.foody.vn timed out. (connect timeout=10)'))


Image:   0%|          | 4/1336 [01:00<7:30:14, 20.28s/it]

HTTPSConnectionPool(host='images.foody.vn', port=443): Max retries exceeded with url: /res/g104/1030781/prof/s640x400/foody-upload-api-foody-mobile-foody-upload-api-foo-200619091742.jpg (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x16bd27ad0>, 'Connection to images.foody.vn timed out. (connect timeout=10)'))


Image:   0%|          | 5/1336 [01:30<8:48:03, 23.80s/it]

HTTPSConnectionPool(host='images.foody.vn', port=443): Max retries exceeded with url: /res/g90/899822/prof/s1242x600/foody-upload-api-foody-mobile-4-190627140327.jpg (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x16bd309d0>, 'Connection to images.foody.vn timed out. (connect timeout=10)'))


Image:   0%|          | 6/1336 [02:00<9:34:24, 25.91s/it]

HTTPSConnectionPool(host='images.foody.vn', port=443): Max retries exceeded with url: /res/g88/872371/prof/s640x400/foody-upload-api-foody-mobile-bmi-6-jpg-181227092408.jpg (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x16bd33550>, 'Connection to images.foody.vn timed out. (connect timeout=10)'))


Image:   1%|          | 7/1336 [02:02<6:40:13, 18.07s/it]

cannot identify image file <_io.BytesIO object at 0x12ac7bb00>


Image:   1%|          | 8/1336 [02:32<8:04:03, 21.87s/it]

HTTPSConnectionPool(host='images.foody.vn', port=443): Max retries exceeded with url: /res/g88/872371/prof/s576x330/foody-upload-api-foody-mobile-bmi-6-jpg-181227092408.jpg (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x16b8c39d0>, 'Connection to images.foody.vn timed out. (connect timeout=10)'))
HTTPSConnectionPool(host='znews-photo.zadn.vn', port=443): Max retries exceeded with url: /w660/Uploaded/zngure/2020_08_12/mysteriousaigon_1.jpg (Caused by NameResolutionError("<urllib3.connection.HTTPSConnection object at 0x16b762d10>: Failed to resolve 'znews-photo.zadn.vn' ([Errno 8] nodename nor servname provided, or not known)"))


Image:   1%|          | 11/1336 [02:34<3:14:59,  8.83s/it]

cannot identify image file <_io.BytesIO object at 0x12b12e6b0>
HTTPSConnectionPool(host='scontent-hkt1-1.cdninstagram.com', port=443): Max retries exceeded with url: /v/t51.2885-15/e35/p1080x1080/119713590_353109792492903_2219476863490742484_n.jpg?_nc_ht=scontent-hkt1-1.cdninstagram.com&_nc_cat=103&_nc_ohc=0_dSKJy_n7MAX9E8Qtz&tp=19&oh=81960da3f1e762850a3518c8b79e391b&oe=5FB2E805 (Caused by NameResolutionError("<urllib3.connection.HTTPSConnection object at 0x16b90b150>: Failed to resolve 'scontent-hkt1-1.cdninstagram.com' ([Errno 8] nodename nor servname provided, or not known)"))


Image:   1%|          | 13/1336 [03:04<4:11:56, 11.43s/it]

HTTPSConnectionPool(host='images.foody.vn', port=443): Max retries exceeded with url: /res/g68/677885/prof/s576x330/foody-mobile-hmb-jpg-524-636371141011785093.jpg (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x16b81f0d0>, 'Connection to images.foody.vn timed out. (connect timeout=10)'))


Image:   1%|▏         | 17/1336 [03:07<1:29:09,  4.06s/it]

cannot identify image file <_io.BytesIO object at 0x12b16fba0>


Image:   1%|▏         | 19/1336 [03:08<50:52,  2.32s/it]  

cannot identify image file <_io.BytesIO object at 0x129981c60>
HTTPSConnectionPool(host='scontent-hkt1-1.cdninstagram.com', port=443): Max retries exceeded with url: /v/t51.2885-15/e35/s1080x1080/120140900_613142515997916_4635415578831372827_n.jpg?_nc_ht=scontent-hkt1-1.cdninstagram.com&_nc_cat=100&_nc_ohc=tiW2NZbj1igAX_GOIEA&_nc_tp=15&oh=b248b79a0f627d529e23eeaee8efab32&oe=5FB37598 (Caused by NameResolutionError("<urllib3.connection.HTTPSConnection object at 0x12b13c5d0>: Failed to resolve 'scontent-hkt1-1.cdninstagram.com' ([Errno 8] nodename nor servname provided, or not known)"))


Image:   2%|▏         | 23/1336 [03:39<3:10:05,  8.69s/it]

HTTPSConnectionPool(host='images.foody.vn', port=443): Max retries exceeded with url: /res/g24/236013/prof/s640x400/foody-mobile-foody-banh-mi-thit-x-402-635989911554020629.jpg (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x16b7f5010>, 'Connection to images.foody.vn timed out. (connect timeout=10)'))


Image:   2%|▏         | 24/1336 [03:41<2:30:05,  6.86s/it]

cannot identify image file <_io.BytesIO object at 0x16bd45710>


Image:   2%|▏         | 29/1336 [04:15<3:50:13, 10.57s/it]

HTTPSConnectionPool(host='images.foody.vn', port=443): Max retries exceeded with url: /res/g4/38046/prof/s576x330/foody-mobile-banh-mi-ba-le-le-thanh-phuong-khanh-hoa.jpg (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x16bb97550>, 'Connection to images.foody.vn timed out. (connect timeout=10)'))


Image:   2%|▏         | 30/1336 [04:25<3:47:05, 10.43s/it]

HTTPConnectionPool(host='www.dulanotes.com', port=80): Read timed out. (read timeout=10)


Image:   3%|▎         | 34/1336 [05:21<6:04:01, 16.78s/it]

HTTPSConnectionPool(host='images.foody.vn', port=443): Max retries exceeded with url: /res/g65/645599/prof/s576x330/foody-mobile-c2-jpg-885-636258551380578147.jpg (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x16b6b9dd0>, 'Connection to images.foody.vn timed out. (connect timeout=10)'))


Image:   3%|▎         | 35/1336 [05:51<7:29:38, 20.74s/it]

HTTPSConnectionPool(host='images.foody.vn', port=443): Max retries exceeded with url: /res/g75/748680/prof/s576x330/foody-upload-api-foody-mobile-hhh-jpg-180608110011.jpg (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x16bb05d90>, 'Connection to images.foody.vn timed out. (connect timeout=10)'))


Image:   3%|▎         | 37/1336 [06:21<6:56:54, 19.26s/it]

HTTPSConnectionPool(host='images.foody.vn', port=443): Max retries exceeded with url: /res/g4/32034/prof/s576x330/foody-mobile-my-muoi-lo-banh-mi-ca-mau.jpg (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x16b829d90>, 'Connection to images.foody.vn timed out. (connect timeout=10)'))
HTTPSConnectionPool(host='scontent-hkt1-1.cdninstagram.com', port=443): Max retries exceeded with url: /v/t51.2885-15/e35/s1080x1080/120041781_880704555792515_7930905104003237800_n.jpg?_nc_ht=scontent-hkt1-1.cdninstagram.com&_nc_cat=111&_nc_ohc=94HqZOezxVoAX9-Pk8k&_nc_tp=15&oh=3650afc268b426ad49f198062c752927&oe=5FB32415 (Caused by NameResolutionError("<urllib3.connection.HTTPSConnection object at 0x16b909b50>: Failed to resolve 'scontent-hkt1-1.cdninstagram.com' ([Errno 8] nodename nor servname provided, or not known)"))


Image:   3%|▎         | 39/1336 [06:51<6:13:55, 17.30s/it]

HTTPSConnectionPool(host='images.foody.vn', port=443): Max retries exceeded with url: /res/g32/311945/prof/s576x330/foody-mobile-hmbs-jpg-525-636173302436962410.jpg (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x16b7f55d0>, 'Connection to images.foody.vn timed out. (connect timeout=10)'))


Image:   3%|▎         | 41/1336 [07:22<5:27:58, 15.20s/it]

HTTPSConnectionPool(host='images.foody.vn', port=443): Max retries exceeded with url: /res/g16/155251/prof/s576x330/foody-mobile-banh-mi-jpg-706-635966604962217805.jpg (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x16b81c6d0>, 'Connection to images.foody.vn timed out. (connect timeout=10)'))
cannot identify image file <_io.BytesIO object at 0x128ba2110>


Image:   3%|▎         | 42/1336 [07:24<4:14:18, 11.79s/it]

cannot identify image file <_io.BytesIO object at 0x12863da30>


Image:   3%|▎         | 43/1336 [07:54<6:03:57, 16.89s/it]

HTTPSConnectionPool(host='images.foody.vn', port=443): Max retries exceeded with url: /res/g69/689992/prof/s576x330/foody-mobile-img_1104-jpg-400-636416759295645069.jpg (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x16bbb11d0>, 'Connection to images.foody.vn timed out. (connect timeout=10)'))


Image:   3%|▎         | 45/1336 [08:25<5:16:31, 14.71s/it]

HTTPSConnectionPool(host='images.foody.vn', port=443): Max retries exceeded with url: /res/g22/219254/prof/s576x330/foody-upload-api-foody-mobile-5-190513120814.jpg (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x16bbb2cd0>, 'Connection to images.foody.vn timed out. (connect timeout=10)'))
cannot identify image file <_io.BytesIO object at 0x12b16fba0>


Image:   4%|▎         | 47/1336 [08:55<4:51:39, 13.58s/it]

HTTPSConnectionPool(host='images.foody.vn', port=443): Max retries exceeded with url: /res/g21/206607/prof/s576x330/foody-mobile-32134-jpg-252-635889635685107785.jpg (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x16b7c4690>, 'Connection to images.foody.vn timed out. (connect timeout=10)'))
HTTPSConnectionPool(host='imagevietnam.vnanet.vn', port=443): Max retries exceeded with url: /upload/Thumnail/2020/1/14/14012020105222243CV6.jpg (Caused by SSLError(SSLError(1, '[SSL: UNSAFE_LEGACY_RENEGOTIATION_DISABLED] unsafe legacy renegotiation disabled (_ssl.c:1006)')))
HTTPSConnectionPool(host='vimoc.com.vn', port=443): Max retries exceeded with url: /image/catalog/banh-my-hoi-an-banh-mi-lanh.jpg (Caused by NameResolutionError("<urllib3.connection.HTTPSConnection object at 0x16bb35150>: Failed to resolve 'vimoc.com.vn' ([Errno 8] nodename nor servname provided, or not known)"))
HTTPSConnectionPool(host='scontent-hkt1-1.cdninstagram.com', port=443): Max retries e

Image:   4%|▍         | 53/1336 [09:07<1:40:30,  4.70s/it]

HTTPSConnectionPool(host='images1.houstonpress.com', port=443): Max retries exceeded with url: /imager/u/original/11043852/roostar-banh-mi-balke.jpg (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x16bbadb90>, 'Connection to images1.houstonpress.com timed out. (connect timeout=10)'))
cannot identify image file <_io.BytesIO object at 0x12b5a31f0>


Image:   4%|▍         | 54/1336 [09:08<1:22:00,  3.84s/it]

cannot identify image file <_io.BytesIO object at 0x11ff28680>


Image:   4%|▍         | 55/1336 [09:09<1:02:44,  2.94s/it]

cannot identify image file <_io.BytesIO object at 0x12863da30>


Image:   4%|▍         | 57/1336 [09:19<1:16:46,  3.60s/it]

HTTPSConnectionPool(host='www.citypassguide.com', port=443): Read timed out. (read timeout=10)
cannot identify image file <_io.BytesIO object at 0x128ba2110>


Image:   4%|▍         | 58/1336 [09:49<4:00:01, 11.27s/it]

HTTPSConnectionPool(host='images.foody.vn', port=443): Max retries exceeded with url: /res/g19/184320/prof/s1242x600/foody-mobile-15745476231_0a512e8a-477-635827491386862536.jpg (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x16bd268d0>, 'Connection to images.foody.vn timed out. (connect timeout=10)'))


Image:   4%|▍         | 59/1336 [10:19<5:56:37, 16.76s/it]

HTTPSConnectionPool(host='images.foody.vn', port=443): Max retries exceeded with url: /res/g87/869326/prof/s576x330/foody-upload-api-foody-mobile-22a-190116081206.jpg (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x16b811510>, 'Connection to images.foody.vn timed out. (connect timeout=10)'))


Image:   5%|▍         | 61/1336 [10:54<6:29:51, 18.35s/it]

HTTPSConnectionPool(host='images.foody.vn', port=443): Max retries exceeded with url: /res/g98/978722/prof/s576x330/foody-upload-api-foody-mobile-hm1-191121145803.jpg (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x16bbafa10>, 'Connection to images.foody.vn timed out. (connect timeout=10)'))


Image:   5%|▍         | 64/1336 [10:55<2:18:35,  6.54s/it]

HTTPSConnectionPool(host='static.vietnammm.com', port=443): Max retries exceeded with url: /images/restaurants/vn/O1105R01/products/combo-tra-banh.png (Caused by SSLError(SSLCertVerificationError(1, '[SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: self-signed certificate (_ssl.c:1006)')))
HTTPSConnectionPool(host='tgtt.onecmscdn.com', port=443): Max retries exceeded with url: /2020/03/22/banh_my.jpg (Caused by NameResolutionError("<urllib3.connection.HTTPSConnection object at 0x16b8c23d0>: Failed to resolve 'tgtt.onecmscdn.com' ([Errno 8] nodename nor servname provided, or not known)"))


Image:   5%|▍         | 65/1336 [11:01<2:14:57,  6.37s/it]

cannot identify image file <_io.BytesIO object at 0x128bd52b0>


Image:   5%|▌         | 68/1336 [11:03<54:13,  2.57s/it]  

cannot identify image file <_io.BytesIO object at 0x12864c270>


Image:   5%|▌         | 69/1336 [11:33<3:47:54, 10.79s/it]

HTTPSConnectionPool(host='images.foody.vn', port=443): Max retries exceeded with url: /res/g32/311228/prof/s576x330/foody-mobile-bm-m-jpg-360-636171400037307985.jpg (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x16b812950>, 'Connection to images.foody.vn timed out. (connect timeout=10)'))
HTTPSConnectionPool(host='scontent-hkt1-1.cdninstagram.com', port=443): Max retries exceeded with url: /v/t51.2885-15/e35/s1080x1080/119681348_106517597872630_6965817227824058991_n.jpg?_nc_ht=scontent-hkt1-1.cdninstagram.com&_nc_cat=110&_nc_ohc=K3sH7m-cAzoAX-gL1mP&_nc_tp=15&oh=c8878c7cba502c0bd713026053fdd43c&oe=5FB2050E (Caused by NameResolutionError("<urllib3.connection.HTTPSConnection object at 0x16bbac410>: Failed to resolve 'scontent-hkt1-1.cdninstagram.com' ([Errno 8] nodename nor servname provided, or not known)"))


Image:   5%|▌         | 71/1336 [11:34<2:07:11,  6.03s/it]

cannot identify image file <_io.BytesIO object at 0x12aaa6480>


Image:   5%|▌         | 73/1336 [12:05<3:58:00, 11.31s/it]

HTTPSConnectionPool(host='images.foody.vn', port=443): Max retries exceeded with url: /res/g96/954103/prof/s1242x600/foody-upload-api-foody-mobile-bm-190906164742.jpg (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x16bbb38d0>, 'Connection to images.foody.vn timed out. (connect timeout=10)'))


Image:   6%|▌         | 74/1336 [12:35<5:45:02, 16.40s/it]

HTTPSConnectionPool(host='images.foody.vn', port=443): Max retries exceeded with url: /res/g66/652495/prof/s640x400/foody-mobile-2-jpg-798-636281219576137094.jpg (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x16bd30150>, 'Connection to images.foody.vn timed out. (connect timeout=10)'))


Image:   6%|▌         | 75/1336 [12:35<4:11:47, 11.98s/it]

cannot identify image file <_io.BytesIO object at 0x12b12f510>


Image:   6%|▌         | 76/1336 [13:05<5:59:41, 17.13s/it]

HTTPSConnectionPool(host='images.foody.vn', port=443): Max retries exceeded with url: /res/g96/954103/prof/s640x400/foody-upload-api-foody-mobile-bm-190906164742.jpg (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x16b82a350>, 'Connection to images.foody.vn timed out. (connect timeout=10)'))


Image:   6%|▌         | 78/1336 [13:36<6:08:42, 17.59s/it]

HTTPSConnectionPool(host='images.foody.vn', port=443): Max retries exceeded with url: /res/g96/956909/prof/s1242x600/foody-upload-api-foody-mobile-sh-190916111047.jpg (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x16b81fd10>, 'Connection to images.foody.vn timed out. (connect timeout=10)'))


Image:   6%|▌         | 79/1336 [14:06<7:25:10, 21.25s/it]

HTTPSConnectionPool(host='images.foody.vn', port=443): Max retries exceeded with url: /res/g97/964306/prof/s576x330/foody-upload-api-foody-mobile-1-191002151618.jpg (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x16bbad010>, 'Connection to images.foody.vn timed out. (connect timeout=10)'))
HTTPSConnectionPool(host='znews-photo.zadn.vn', port=443): Max retries exceeded with url: /w660/Uploaded/kbd_bcvi/2019_07_31/1.jpg (Caused by NameResolutionError("<urllib3.connection.HTTPSConnection object at 0x16bbb3ed0>: Failed to resolve 'znews-photo.zadn.vn' ([Errno 8] nodename nor servname provided, or not known)"))


Image:   6%|▌         | 81/1336 [14:36<6:25:18, 18.42s/it]

HTTPSConnectionPool(host='images.foody.vn', port=443): Max retries exceeded with url: /res/g64/637502/prof/s576x330/foody-mobile-hmbbd-jpg-957-636238147011323491.jpg (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x16bbe0950>, 'Connection to images.foody.vn timed out. (connect timeout=10)'))


Image:   6%|▌         | 82/1336 [15:06<7:24:37, 21.27s/it]

HTTPSConnectionPool(host='images.foody.vn', port=443): Max retries exceeded with url: /res/g81/803351/prof/s576x330/foody-upload-api-foody-mobile-bmi-6-jpg-181119164447.jpg (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x16bbd8550>, 'Connection to images.foody.vn timed out. (connect timeout=10)'))


Image:   6%|▌         | 83/1336 [15:37<8:11:58, 23.56s/it]

HTTPSConnectionPool(host='images.foody.vn', port=443): Max retries exceeded with url: /res/g95/946419/prof/s576x330/foody-upload-api-foody-mobile-avar22-190807160257.jpg (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x16bbf6a10>, 'Connection to images.foody.vn timed out. (connect timeout=10)'))


Image:   6%|▋         | 85/1336 [15:38<4:25:29, 12.73s/it]

HTTPSConnectionPool(host='cdn.usarestaurants.info', port=443): Max retries exceeded with url: /assets/uploads/45ebf7dadd9373fb73a9f18708459e4a_-united-states-california-los-angeles-county-rosemead-343508-banh-mi-my-tho-vietnamese-sandwichhtm.jpg (Caused by NameResolutionError("<urllib3.connection.HTTPSConnection object at 0x16b8c3a50>: Failed to resolve 'cdn.usarestaurants.info' ([Errno 8] nodename nor servname provided, or not known)"))


Image:   7%|▋         | 89/1336 [15:41<1:15:03,  3.61s/it]

cannot identify image file <_io.BytesIO object at 0x12b5a2610>
cannot identify image file <_io.BytesIO object at 0x12a9eaca0>


Image:   7%|▋         | 90/1336 [15:41<54:10,  2.61s/it]  

cannot identify image file <_io.BytesIO object at 0x11ff28680>


Image:   7%|▋         | 91/1336 [16:11<3:43:12, 10.76s/it]

HTTPSConnectionPool(host='images.foody.vn', port=443): Max retries exceeded with url: /res/g63/629911/prof/s576x330/foody-mobile-16832276_58657612819-820-636234451741823805.jpg (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x16bbb34d0>, 'Connection to images.foody.vn timed out. (connect timeout=10)'))


Image:   7%|▋         | 93/1336 [16:14<2:02:11,  5.90s/it]

cannot identify image file <_io.BytesIO object at 0x12864c270>


Image:   7%|▋         | 94/1336 [16:14<1:27:16,  4.22s/it]

cannot identify image file <_io.BytesIO object at 0x12b12f510>


Image:   7%|▋         | 96/1336 [16:15<47:14,  2.29s/it]  

cannot identify image file <_io.BytesIO object at 0x16b7ced90>


Image:   7%|▋         | 97/1336 [16:45<3:39:01, 10.61s/it]

HTTPSConnectionPool(host='images.foody.vn', port=443): Max retries exceeded with url: /res/g80/792208/prof/s640x400/foody-upload-api-foody-mobile-hmn265-jpg-181102180502.jpg (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x16bbb2dd0>, 'Connection to images.foody.vn timed out. (connect timeout=10)'))


Image:   7%|▋         | 100/1336 [16:46<1:22:50,  4.02s/it]

cannot identify image file <_io.BytesIO object at 0x12864c270>


Image:   8%|▊         | 101/1336 [17:16<4:03:11, 11.81s/it]

HTTPSConnectionPool(host='images.foody.vn', port=443): Max retries exceeded with url: /brand/s1170x300/foody-tuan-map-635128683446601250.jpg (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x16bd28890>, 'Connection to images.foody.vn timed out. (connect timeout=10)'))


Image:   8%|▊         | 105/1336 [17:20<1:13:04,  3.56s/it]

cannot identify image file <_io.BytesIO object at 0x128b7bd30>


Image:   8%|▊         | 106/1336 [17:50<3:56:38, 11.54s/it]

HTTPSConnectionPool(host='images.foody.vn', port=443): Max retries exceeded with url: /res/g9/84062/prof/s576x330/foody-mobile-banh-my-pate-doi-can-ha-noi-140627111218.jpg (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x16b761250>, 'Connection to images.foody.vn timed out. (connect timeout=10)'))


Image:   8%|▊         | 106/1336 [17:57<3:28:21, 10.16s/it]


KeyboardInterrupt: 

In [None]:
# Asynchronous version: download the images for all dishes concurrently

urls = {
    "Banh mi": "https://storage.googleapis.com/kagglesdsdata/datasets/1050510/2399712/Urls/Banh%20mi.txt?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=gcp-kaggle-com%40kaggle-161607.iam.gserviceaccount.com%2F20251018%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20251018T134310Z&X-Goog-Expires=259200&X-Goog-SignedHeaders=host&X-Goog-Signature=68d019c534a2be40b1e553ddf34b8c68da95c5529e109c4acf696261b7c224d010765da96cdcc88d43dc9b4d8e26b427e706c144c9ce43f8a076feda1ac35db2645ec63c7fef5b80b9c91f45738598ea99642ccc2a1e3e33b30bcc094574cb172f38f15581e7fe6ca440074471405fe9a70be3d7755d4a5af031c30b940ca0fe2521357b8988c35efe2a935c5141a3aef4051fce851a27ac1d4506b4160ab608eed939e221dcc3c92699d90cab41006360f3d13b7125bd267ead9c813029e539a3a61119bad08473a7d1cdae85d28bbb562b1f0183cf887e1c9934f820460eb7ee81f7c8d1c638fe316ca2dd235a76ca3010b09a424afc21e45934abd639547c",
    "Pho": "https://storage.googleapis.com/kagglesdsdata/datasets/1050510/2399712/Urls/Pho.txt?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=gcp-kaggle-com%40kaggle-161607.iam.gserviceaccount.com%2F20251018%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20251018T134427Z&X-Goog-Expires=259200&X-Goog-SignedHeaders=host&X-Goog-Signature=a3a05ed4aaf09ba2ea9156826617d06f64c1415e88f1c34ce5b8197f47ec3d9b238a1854a2ec43127d124e09e3b7e6bbb06b7bf625f6c681dda7aa982174777ef13120ea10e2544638812f5e1218c4f816d547567d2affc5cd1a14380958b9302a07976412daadc04082fa405bf1c122a07d038f64e31827213248d290fd8049f8399d7864a09654b3604bffbdfe52000d8ef2e461c85371fcd22e02ff036579d8652a96066da7e5d3fa145e06fe90480299543aac0885b1c0039880690da1c0a9ee597e28ed8450ddfa770435ff13b78532efa263de39b74dd7b5ed6f292d6167c7fde366f0c48ca4f762f97758e07e2cad651bf44117eb999081fcac63ace7",
    "Mi Quang": "https://storage.googleapis.com/kagglesdsdata/datasets/1050510/2399712/Urls/Mi%20Quang.txt?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=gcp-kaggle-com%40kaggle-161607.iam.gserviceaccount.com%2F20251018%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20251018T134506Z&X-Goog-Expires=259200&X-Goog-SignedHeaders=host&X-Goog-Signature=4d505a4cde97d142497238b475b1785f5b5f717e7a56f68f37c436c1459d0348285e1f18b7ae05ad132cd6d162115c4ab3eaa048477bd1a1c6d64af0140dce14bbb8f297ade16f413d8f056e59d6018e2a20ad0d44c53ce60b1120d31c640afd4b9095a42f041bf8a61003efb45af6b5b9aecf8adcc907c474b55ba87095d5ffc2fcf98c200301ab6469aaf983bfcf6781c90731d02f42cf2a50292ce2d890f90ccc3a3805fab56960d1a29d642c4a298a85377f446eae4b215067bf1643d678a6d183d5adfa555c7f76a7e659aadb59a142f48a6bdda2464ee1e65bb5804e270a54b675e3f1ce412c4d03013d1c9b7a6f46721a81220c779ca36ce2d72e8495",
}
async def get_food_urls(endpoint):
    """Fetch the text file at `endpoint` and return its non-empty lines (one image URL each)."""
    async with httpx.AsyncClient(timeout=10) as client:
        response = await client.get(endpoint)
        lines = response.content.decode("utf-8").split("\n")
        return [line.strip() for line in lines if line.strip()]


async def write_images_to_disk(food_name, food_urls):
    """Download every URL in `food_urls` concurrently into dataset/<food_name>/."""
    food_dir = os.path.join(ROOT_DIR, "dataset", food_name)
    os.makedirs(food_dir, exist_ok=True)

    # One coroutine per URL; asyncio.gather runs them concurrently on the event loop.
    save_image_tasks = [
        save_image(food_url, f"image_{i}.png", food_dir)
        for i, food_url in enumerate(food_urls, start=1)
    ]
    await asyncio.gather(*save_image_tasks)

    print(f"Finish downloading images for {food_name}")


async def save_image(food_url, img_name, food_dir):
    """Download one image and save it as PNG; log and skip any failure."""
    try:
        async with httpx.AsyncClient(timeout=20) as client:
            food_response = await client.get(food_url)
        img = Image.open(BytesIO(food_response.content))
        img.load()  # force full decoding so corrupt files fail here, not later
        img.save(os.path.join(food_dir, img_name), "PNG")
    except Exception as e:
        print("error occur with url: ", food_url)
        print(e)


async def download_food_image(food_name, endpoint):
    food_urls = await get_food_urls(endpoint)
    await write_images_to_disk(food_name, food_urls)


print("Start setting up data")
t1 = time()
food_image_tasks = [download_food_image(food_name, endpoint) for food_name, endpoint in urls.items()]
# Top-level `await` works here because the notebook already runs an event loop.
await asyncio.gather(*food_image_tasks)
t2 = time()
print(f"DONE! Total time taken for asynchronous processing: {t2 - t1}s")

Start setting up data
error occur with url:  https://scontent-hkt1-1.cdninstagram.com/v/t51.2885-15/e35/s1080x1080/120138700_150298896733920_9091251351111543872_n.jpg?_nc_ht=scontent-hkt1-1.cdninstagram.com&_nc_cat=100&_nc_ohc=S3dodtxl8roAX9HaGez&_nc_tp=15&oh=64397ecf3d10e680d43c6c6c9c529b17&oe=5FB2AD6F
[Errno 8] nodename nor servname provided, or not known
error occur with url:  https://scontent-hkt1-1.cdninstagram.com/v/t51.2885-15/e35/s1080x1080/118702287_3312335992207719_4848107457462024764_n.jpg?_nc_ht=scontent-hkt1-1.cdninstagram.com&_nc_cat=103&_nc_ohc=QG8MT9kfIJkAX9dpSVi&_nc_tp=15&oh=3385dbeffece5fce12b65ad30be61709&oe=5FB39709
[Errno 8] nodename nor servname provided, or not known
error occur with url:  https://scontent-hkt1-1.cdninstagram.com/v/t51.2885-15/e35/s1080x1080/121121998_3673195146065754_1501120690075760719_n.jpg?_nc_ht=scontent-hkt1-1.cdninstagram.com&_nc_cat=106&_nc_ohc=nk_OsmAdFOAAX_fB6Be&_nc_tp=15&oh=beb5d7a92665391f33e0f46632a01ad3&oe=5FB39857
[Errno 8] nodenam
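The asynchronous cell above starts one request per URL all at once, which can mean thousands of simultaneous connections. One common way to cap this is an `asyncio.Semaphore`. The sketch below is only an illustration under that assumption: `gather_with_limit` and `fake_save_image` are hypothetical names, and the fake task sleeps instead of touching the network.

```python
import asyncio


async def gather_with_limit(coro_fns, limit):
    """Run zero-argument coroutine functions with at most `limit` in flight."""
    semaphore = asyncio.Semaphore(limit)

    async def guarded(fn):
        # Wait for a free slot before starting the wrapped coroutine.
        async with semaphore:
            return await fn()

    return await asyncio.gather(*(guarded(fn) for fn in coro_fns))


# Demo with a fake download that records how many tasks overlap.
in_flight = 0
peak = 0


async def fake_save_image():
    global in_flight, peak
    in_flight += 1
    peak = max(peak, in_flight)
    await asyncio.sleep(0.01)  # stands in for the HTTP request
    in_flight -= 1
    return "ok"


results = asyncio.run(gather_with_limit([fake_save_image] * 20, limit=5))
print(f"peak concurrency: {peak}, finished tasks: {len(results)}")
```

With the notebook's own function, each entry of `coro_fns` could be something like `functools.partial(save_image, food_url, img_name, food_dir)`; the right limit depends on the network and the remote hosts, so treat any specific value as a guess to tune.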

<a id="2"></a> 

### 2. Data preparation

In [50]:
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms

LABELS = ['Banh mi', 'Mi Quang', 'Pho']
NUM_CLASSES = len(LABELS)


class FoodDataset(Dataset):
    """Images stored as <data_dir>/<dish name>/<file>.png, labelled by index into LABELS."""

    def __init__(self, data_dir, transform=None, shuffle=True):
        super().__init__()
        self.data_dir = data_dir
        self.transform = transform
        self.data = []
        for food_name in os.listdir(self.data_dir):
            food_dir = os.path.join(self.data_dir, food_name)
            # Skip stray files (e.g. .DS_Store) and folders that are not known classes.
            if not os.path.isdir(food_dir) or food_name not in LABELS:
                continue
            label = LABELS.index(food_name)
            for file in os.listdir(food_dir):
                if not file.lower().endswith(".png"):
                    continue
                self.data.append((os.path.join(food_dir, file), label))

        if shuffle:
            # Reorder via a random permutation; keeping a Python list preserves the int labels.
            order = np.random.permutation(len(self.data))
            self.data = [self.data[i] for i in order]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        img_dir, label = self.data[index]
        img = Image.open(img_dir).convert("RGB")

        if self.transform:
            img = self.transform(img)

        y = torch.tensor(int(label))
        return img, y


DATA_DIR = os.path.join(ROOT_DIR, "dataset")
# Convert to a tensor, resize to 224x224, then normalize with the ImageNet mean and std.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize((224, 224)),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
dataset = FoodDataset(DATA_DIR, transform)
# 80/20 train/validation split.
train_ds, val_ds = random_split(dataset, [0.8, 0.2])

BATCH_SIZE = 16
train_loader = DataLoader(dataset=train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(dataset=val_ds, batch_size=BATCH_SIZE)
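The `LABELS` list above is hard-coded and has to match the folder names under `dataset/`. As a possible alternative, the mapping could be derived from the directory itself. The following is a minimal sketch under that assumption; `build_class_index` is a hypothetical helper, not something the notebook defines, and the demo uses a temporary directory instead of the real dataset.

```python
import os
import tempfile


def build_class_index(data_dir):
    """Map each class sub-directory name to an integer label, in sorted order."""
    class_names = sorted(
        entry.name
        for entry in os.scandir(data_dir)
        if entry.is_dir() and not entry.name.startswith(".")
    )
    return {name: idx for idx, name in enumerate(class_names)}


# Demo on a throwaway directory that mimics dataset/<dish name>/.
with tempfile.TemporaryDirectory() as tmp:
    for dish in ["Pho", "Banh mi", "Mi Quang"]:
        os.mkdir(os.path.join(tmp, dish))
    open(os.path.join(tmp, ".DS_Store"), "w").close()  # stray file, ignored
    class_index = build_class_index(tmp)

print(class_index)
```

Sorting the folder names keeps the label assignment stable across runs, and for these three dishes it yields the same order as the hard-coded `LABELS` list.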

<a id="3"></a> 

### 3. Training model

In [51]:
# Baseline: a small CNN trained from scratch (not a fine-tuned MobileNetV2)
class FoodCNN(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.sequential = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(in_features=128, out_features=64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(in_features=64, out_features=32),
            nn.ReLU(),
            nn.Linear(32, len(LABELS)),
        )

    def forward(self, X):
        return self.sequential(X)

In [52]:
def get_label(logits):
    # Softmax is monotonic, so the argmax of the raw logits is already the predicted class.
    return logits.argmax(dim=1).detach().cpu().tolist()

In [None]:
from sklearn.metrics import accuracy_score
from torchsummary import summary
device = "mps" if torch.backends.mps.is_available() else "cpu"  # Apple-silicon GPU when available
model = FoodCNN()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
criteria = nn.CrossEntropyLoss()

summary(model, input_size=(3, 224, 224))


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 32, 222, 222]             896
              ReLU-2         [-1, 32, 222, 222]               0
            Conv2d-3         [-1, 32, 220, 220]           9,248
              ReLU-4         [-1, 32, 220, 220]               0
         MaxPool2d-5         [-1, 32, 110, 110]               0
            Conv2d-6         [-1, 64, 108, 108]          18,496
              ReLU-7         [-1, 64, 108, 108]               0
            Conv2d-8        [-1, 128, 106, 106]          73,856
              ReLU-9        [-1, 128, 106, 106]               0
        MaxPool2d-10          [-1, 128, 53, 53]               0
AdaptiveAvgPool2d-11            [-1, 128, 1, 1]               0
          Flatten-12                  [-1, 128]               0
           Linear-13                   [-1, 64]           8,256
             ReLU-14                   
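
The output shapes and parameter counts in the summary above can be checked by hand. The snippet below is a minimal sketch of that arithmetic, assuming unpadded 3x3 convolutions with stride 1 and 2x2 max pooling with stride 2, as in `FoodCNN`. The helper names (`conv_out`, `conv_params`, `linear_params`) are made up for this illustration and are not part of PyTorch.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1


def conv_params(c_in, c_out, kernel):
    """Weights plus one bias per output channel of a Conv2d layer."""
    return c_out * (c_in * kernel * kernel + 1)


def linear_params(n_in, n_out):
    """Weights plus one bias per output unit of a Linear layer."""
    return n_out * (n_in + 1)


# (layer name, kernel size, stride) in the order used by FoodCNN
layers = [
    ("conv1", 3, 1),
    ("conv2", 3, 1),
    ("pool1", 2, 2),
    ("conv3", 3, 1),
    ("conv4", 3, 1),
    ("pool2", 2, 2),
]

size = 224  # input images are 224x224
trace = []
for name, kernel, stride in layers:
    size = conv_out(size, kernel, stride)
    trace.append((name, size))

for name, out_size in trace:
    print(f"{name}: {out_size}x{out_size}")

print("conv1 params:", conv_params(3, 32, 3))
print("conv4 params:", conv_params(64, 128, 3))
print("first linear params:", linear_params(128, 64))
```

The sizes go 224, 222, 220, 110, 108, 106, 53, which matches the `Output Shape` column, and the parameter counts match the `Param #` column for the layers shown.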

In [54]:


# Training
model = model.to(device=device)
EPOCHS = 50
for epoch in tqdm(range(EPOCHS), desc="Epoch"):
    model.train()
    total_train_loss = 0
    y_true_train, y_pred_train = [], []
    for inputs_train, labels_train in train_loader:
        inputs_train = inputs_train.to(device=device)
        labels_train = labels_train.to(device=device)
        train_logits = model(inputs_train)
        train_loss = criteria(train_logits, labels_train)
        total_train_loss += train_loss.item()
        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()

        preds_train = get_label(train_logits)
        y_true_train.extend(labels_train.cpu().detach().numpy().tolist())
        y_pred_train.extend(preds_train)

    avg_train_loss = total_train_loss / len(train_loader)
    train_acc = accuracy_score(y_true_train, y_pred_train)
    model.eval()
    total_val_loss = 0
    y_true_val, y_pred_val = [], []
    with torch.inference_mode():
        for inputs_val, labels_val in val_loader:
            inputs_val = inputs_val.to(device=device)
            labels_val = labels_val.to(device=device)
            val_logits = model(inputs_val)
            val_loss = criteria(val_logits, labels_val)
            total_val_loss += val_loss.item()
            preds_val = get_label(val_logits)
            y_true_val.extend(labels_val.cpu().detach().numpy().tolist())
            y_pred_val.extend(preds_val)

    avg_val_loss = total_val_loss / len(val_loader)
    val_acc = accuracy_score(y_true_val, y_pred_val)
    if epoch % 10 == 0:
        print(
            f"Epoch [{epoch}|{EPOCHS}] \n Training loss: {round(avg_train_loss, 5)} - Train acc: {round(train_acc,5)}. \n Validation loss: {round(avg_val_loss, 5)} - Val acc: {round(val_acc, 5)} "
        )

Epoch:   2%|▏         | 1/50 [00:17<14:16, 17.48s/it]

Epoch [0|50] 
 Training loss: 1.07881 - Train acc: 0.4375. 
 Validation loss: 1.10291 - Val acc: 0.39024 


Epoch:  22%|██▏       | 11/50 [03:11<11:16, 17.34s/it]

Epoch [10|50] 
 Training loss: 0.85899 - Train acc: 0.62097. 
 Validation loss: 0.91084 - Val acc: 0.56098 


Epoch:  42%|████▏     | 21/50 [06:02<08:15, 17.08s/it]

Epoch [20|50] 
 Training loss: 0.70764 - Train acc: 0.67944. 
 Validation loss: 0.77219 - Val acc: 0.65854 


Epoch:  62%|██████▏   | 31/50 [08:56<05:33, 17.57s/it]

Epoch [30|50] 
 Training loss: 0.5376 - Train acc: 0.76411. 
 Validation loss: 0.70402 - Val acc: 0.69106 


Epoch:  82%|████████▏ | 41/50 [11:50<02:35, 17.27s/it]

Epoch [40|50] 
 Training loss: 0.4304 - Train acc: 0.81653. 
 Validation loss: 0.65343 - Val acc: 0.72358 


Epoch: 100%|██████████| 50/50 [14:26<00:00, 17.33s/it]


In [55]:
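# Re-evaluate on the validation set; reset the loss accumulator left over from training
model.eval()
total_val_loss = 0
y_true_val, y_pred_val = [], []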
with torch.inference_mode():
    for inputs_val, labels_val in tqdm(val_loader, desc="Validation"):
        inputs_val = inputs_val.to(device=device)
        labels_val = labels_val.to(device=device)
        val_logits = model(inputs_val)
        val_loss = criteria(val_logits, labels_val)
        total_val_loss += val_loss.item()
        preds_val = get_label(val_logits)
        y_true_val.extend(labels_val.cpu().detach().numpy().tolist())
        y_pred_val.extend(preds_val)

avg_val_loss = total_val_loss / len(val_loader)
val_acc = accuracy_score(y_true_val, y_pred_val)
print(val_acc)

Validation: 100%|██████████| 8/8 [00:02<00:00,  2.72it/s]

0.7967479674796748





In [57]:
PATH = "model_weights.pth"
torch.save(model.state_dict(), PATH)

In [59]:
# Continue training for 30 more epochs
for epoch in tqdm(range(EPOCHS, EPOCHS + 30), desc="Epoch"):
    model.train()
    total_train_loss = 0
    y_true_train, y_pred_train = [], []
    for inputs_train, labels_train in train_loader:
        inputs_train = inputs_train.to(device=device)
        labels_train = labels_train.to(device=device)
        train_logits = model(inputs_train)
        train_loss = criteria(train_logits, labels_train)
        total_train_loss += train_loss.item()
        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()

        preds_train = get_label(train_logits)
        y_true_train.extend(labels_train.cpu().detach().numpy().tolist())
        y_pred_train.extend(preds_train)

    avg_train_loss = total_train_loss / len(train_loader)
    train_acc = accuracy_score(y_true_train, y_pred_train)
    model.eval()
    total_val_loss = 0
    y_true_val, y_pred_val = [], []
    with torch.inference_mode():
        for inputs_val, labels_val in val_loader:
            inputs_val = inputs_val.to(device=device)
            labels_val = labels_val.to(device=device)
            val_logits = model(inputs_val)
            val_loss = criteria(val_logits, labels_val)
            total_val_loss += val_loss.item()
            preds_val = get_label(val_logits)
            y_true_val.extend(labels_val.cpu().detach().numpy().tolist())
            y_pred_val.extend(preds_val)

    avg_val_loss = total_val_loss / len(val_loader)
    val_acc = accuracy_score(y_true_val, y_pred_val)
    if epoch % 10 == 0:
        print(
            f"Epoch [{epoch+ 1}|{EPOCHS + 30}] \n Training loss: {round(avg_train_loss, 5)} - Train acc: {round(train_acc,5)}. \n Validation loss: {round(avg_val_loss, 5)} - Val acc: {round(val_acc, 5)} "
        )

Epoch:   3%|▎         | 1/30 [00:17<08:16, 17.11s/it]

Epoch [51|80] 
 Training loss: 0.34349 - Train acc: 0.85685. 
 Validation loss: 0.64527 - Val acc: 0.73984 


Epoch:  37%|███▋      | 11/30 [03:11<05:32, 17.52s/it]

Epoch [61|80] 
 Training loss: 0.30184 - Train acc: 0.89113. 
 Validation loss: 0.69288 - Val acc: 0.73984 


Epoch:  70%|███████   | 21/30 [06:06<02:34, 17.15s/it]

Epoch [71|80] 
 Training loss: 0.22351 - Train acc: 0.91129. 
 Validation loss: 0.77208 - Val acc: 0.73984 


Epoch: 100%|██████████| 30/30 [08:40<00:00, 17.33s/it]
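
In the log above, the validation loss rises from 0.645 to 0.772 while the training loss keeps falling, which usually points to overfitting. One common response is early stopping. The snippet below is only a sketch of the idea: the `EarlyStopper` class and its `patience` and `min_delta` parameters are hypothetical names introduced here, and the loss values are the ones printed every 10 epochs in this notebook, not per-epoch values.

```python
class EarlyStopper:
    """Signal a stop once validation loss has not improved for `patience` checks."""

    def __init__(self, patience=2, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float("inf")
        self.bad_checks = 0

    def should_stop(self, val_loss):
        if val_loss < self.best_loss - self.min_delta:
            # Improvement: remember it and reset the counter
            self.best_loss = val_loss
            self.bad_checks = 0
        else:
            self.bad_checks += 1
        return self.bad_checks >= self.patience


# Validation losses logged at epochs 0, 10, 20, 30, 40 and then 51, 61, 71
logged_val_losses = [
    1.10291, 0.91084, 0.77219, 0.70402, 0.65343,
    0.64527, 0.69288, 0.77208,
]

stopper = EarlyStopper(patience=2)
stop_index = None
for i, loss in enumerate(logged_val_losses):
    if stopper.should_stop(loss):
        stop_index = i
        break

print("best validation loss:", stopper.best_loss)
print("stop at logged check:", stop_index)
```

In a real loop, `should_stop` would be called once per epoch after the validation pass, and the weights would be saved whenever `best_loss` improves.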


In [None]:
ROOT_DIR = os.path.abspath(".")
MODEL_WEIGHT_PATH = os.path.join(ROOT_DIR, "model_weights.pth")

def load_model():
    model = FoodCNN()
    state_dict = torch.load(MODEL_WEIGHT_PATH, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    return model

model = load_model()
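# Also save the whole model object; "model_full.pth" is an assumed file name
torch.save(model, os.path.join(ROOT_DIR, "model_full.pth"))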

### 4. Transfer learning with MobileNet to recognize more foods

In [1]:
# Download the full dataset from Kaggle
import kagglehub

# Download latest version
path = kagglehub.dataset_download("quandang/vietnamese-foods")

print("Path to dataset files:", path)



Downloading from https://www.kaggle.com/api/v1/datasets/download/quandang/vietnamese-foods?dataset_version_number=11...


100%|██████████| 4.17G/4.17G [21:39<00:00, 3.45MB/s] 

Extracting files...





Path to dataset files: /Users/mac/.cache/kagglehub/datasets/quandang/vietnamese-foods/versions/11


In [15]:
import shutil
ORIGINAL_IMAGE_PATH = "/Users/mac/.cache/kagglehub/datasets/quandang/vietnamese-foods/versions/11/Images"
NEW_DS_DIR = os.path.join(ROOT_DIR, '30_foods_dataset')
if not os.path.exists(NEW_DS_DIR):
    os.mkdir(NEW_DS_DIR)
# Copy each split (train / val / test) into the project folder
for dataset_type in os.listdir(ORIGINAL_IMAGE_PATH):
    original_dataset_path = os.path.join(ORIGINAL_IMAGE_PATH, dataset_type)
    if 'val' in dataset_type.lower():
        new_dataset_path = os.path.join(NEW_DS_DIR, "val")
    elif 'train' in dataset_type.lower():
        new_dataset_path = os.path.join(NEW_DS_DIR, "train")
    else:
        new_dataset_path = os.path.join(NEW_DS_DIR, "test")

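    # dirs_exist_ok=True (Python 3.8+) lets this cell be re-run without a FileExistsError
    shutil.copytree(original_dataset_path, new_dataset_path, dirs_exist_ok=True)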

In [9]:
from torchvision import datasets, models, transforms

In [78]:
data_dir = NEW_DS_DIR
data_transforms = {
    "train": transforms.Compose(
        [
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    ),
    "val": transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    ),
}

image_datasets = {
    x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
    for x in ["train", "val"]
}
dataloaders = {
    x: torch.utils.data.DataLoader(
        image_datasets[x], batch_size=128, shuffle=(x == "train")  # only the training split needs shuffling
    )
    for x in ["train", "val"]
}
dataset_sizes = {x: len(image_datasets[x]) for x in ["train", "val"]}
class_names = image_datasets["train"].classes

device = (
    torch.accelerator.current_accelerator().type
    if torch.accelerator.is_available()
    else "cpu"
)
print(f"Using {device} device")

Using mps device


In [79]:
class FoodMobileNet(nn.Module):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.feature_extractor = models.mobilenet_v3_small(
            weights=models.MobileNet_V3_Small_Weights.IMAGENET1K_V1
        )
    
        # Freeze model weights
        for param in self.feature_extractor.parameters():
            param.requires_grad = False
        num_ftrs = self.feature_extractor.classifier[-1].in_features
        # Drop the original ImageNet output layer and add a new head sized for our classes
        self.feature_extractor.classifier = self.feature_extractor.classifier[:-1]
        self.fc = nn.Linear(num_ftrs, len(class_names))

    def forward(self, X):
        X = self.feature_extractor(X)
        X = self.fc(X)
        return X


model = FoodMobileNet().to(device=device, dtype=torch.float32)

criterion = nn.CrossEntropyLoss()

# Only the new `fc` head is trainable: the backbone is frozen, so its
# parameters never receive gradients and Adam leaves them unchanged
optimizer_ft = torch.optim.Adam(model.parameters(), lr=1e-3)

# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
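
To see what this schedule does over a short run, the learning rate can be computed in closed form. This is a small sketch in plain Python, assuming the scheduler is stepped once per epoch; `step_lr` is a hypothetical helper written for this note, not a PyTorch function.

```python
def step_lr(base_lr, epoch, step_size=7, gamma=0.1):
    """Learning rate used during `epoch` (0-indexed) under a StepLR-style decay."""
    return base_lr * gamma ** (epoch // step_size)


# Same settings as the cell above: lr=1e-3, step_size=7, gamma=0.1
schedule = [step_lr(1e-3, epoch) for epoch in range(10)]

for epoch, lr in enumerate(schedule):
    print(f"epoch {epoch}: lr = {lr:.0e}")
```

For a 10-epoch run, epochs 0 to 6 use 1e-3 and epochs 7 to 9 use 1e-4, so the decay only applies to the last three epochs.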

In [80]:
from tqdm import tqdm

In [83]:
from tempfile import TemporaryDirectory
def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    since = time()

    # Create a temporary directory to save training checkpoints
    with TemporaryDirectory() as tempdir:
        best_model_params_path = os.path.join(tempdir, "best_model_params.pt")

        torch.save(model.state_dict(), best_model_params_path)
        best_acc = 0.0

        for epoch in range(num_epochs):
            print(f"Epoch {epoch}/{num_epochs - 1}")
            print("-" * 10)

            # Each epoch has a training and validation phase
            for phase in ["train", "val"]:
                if phase == "train":
                    model.train()  # Set model to training mode
                else:
                    model.eval()  # Set model to evaluate mode

                running_loss = 0.0
                running_corrects = 0

                # Iterate over data.
                for inputs, labels in tqdm(dataloaders[phase], desc=phase):
                    inputs = inputs.to(device, dtype = torch.float32)
                    labels = labels.to(device)

                    # zero the parameter gradients
                    optimizer.zero_grad()

                    # forward
                    # track history if only in train
                    with torch.set_grad_enabled(phase == "train"):
                        outputs = model(inputs)
                        _, preds = torch.max(outputs, 1)
                        loss = criterion(outputs, labels)

                        # backward + optimize only if in training phase
                        if phase == "train":
                            loss.backward()
                            optimizer.step()

                    # statistics
                    running_loss += loss.item() * inputs.size(0)
                    running_corrects += torch.sum(preds == labels.data)
                if phase == "train":
                    scheduler.step()

                epoch_loss = running_loss / dataset_sizes[phase]
                epoch_acc = running_corrects.float()/ dataset_sizes[phase]

                print(f"{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}")

                # checkpoint the weights whenever validation accuracy improves
                if phase == "val" and epoch_acc > best_acc:
                    best_acc = epoch_acc
                    torch.save(model.state_dict(), best_model_params_path)

            print()

        time_elapsed = time() - since
        print(
            f"Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s"
        )
        print(f"Best val Acc: {best_acc:4f}")

        # load best model weights
        model.load_state_dict(torch.load(best_model_params_path, weights_only=True))
    return model

In [85]:
trained_model = train_model(model, criterion, optimizer_ft, exp_lr_scheduler, num_epochs=10)

Epoch 0/9
----------


train: 100%|██████████| 138/138 [02:51<00:00,  1.24s/it]


train Loss: 1.7300 Acc: 0.5334


val: 100%|██████████| 20/20 [00:27<00:00,  1.35s/it]


val Loss: 1.3975 Acc: 0.6191

Epoch 1/9
----------


train: 100%|██████████| 138/138 [02:54<00:00,  1.26s/it]


train Loss: 1.5819 Acc: 0.5610


val: 100%|██████████| 20/20 [00:28<00:00,  1.40s/it]


val Loss: 1.3333 Acc: 0.6274

Epoch 2/9
----------


train: 100%|██████████| 138/138 [02:53<00:00,  1.26s/it]


train Loss: 1.5111 Acc: 0.5751


val: 100%|██████████| 20/20 [00:28<00:00,  1.40s/it]


val Loss: 1.2934 Acc: 0.6330

Epoch 3/9
----------


train: 100%|██████████| 138/138 [02:54<00:00,  1.26s/it]


train Loss: 1.4629 Acc: 0.5865


val: 100%|██████████| 20/20 [00:28<00:00,  1.41s/it]


val Loss: 1.2686 Acc: 0.6433

Epoch 4/9
----------


train: 100%|██████████| 138/138 [02:55<00:00,  1.27s/it]


train Loss: 1.4458 Acc: 0.5934


val: 100%|██████████| 20/20 [00:28<00:00,  1.40s/it]


val Loss: 1.2413 Acc: 0.6517

Epoch 5/9
----------


train: 100%|██████████| 138/138 [02:53<00:00,  1.26s/it]


train Loss: 1.4207 Acc: 0.5929


val: 100%|██████████| 20/20 [00:27<00:00,  1.39s/it]


val Loss: 1.2313 Acc: 0.6517

Epoch 6/9
----------


train: 100%|██████████| 138/138 [02:53<00:00,  1.26s/it]


train Loss: 1.3872 Acc: 0.6041


val: 100%|██████████| 20/20 [00:28<00:00,  1.42s/it]


val Loss: 1.2246 Acc: 0.6525

Epoch 7/9
----------


train: 100%|██████████| 138/138 [02:52<00:00,  1.25s/it]


train Loss: 1.3890 Acc: 0.6038


val: 100%|██████████| 20/20 [00:27<00:00,  1.36s/it]


val Loss: 1.2260 Acc: 0.6501

Epoch 8/9
----------


train: 100%|██████████| 138/138 [02:50<00:00,  1.23s/it]


train Loss: 1.3726 Acc: 0.6056


val: 100%|██████████| 20/20 [00:27<00:00,  1.38s/it]


val Loss: 1.2204 Acc: 0.6521

Epoch 9/9
----------


train: 100%|██████████| 138/138 [02:51<00:00,  1.24s/it]


train Loss: 1.3680 Acc: 0.6095


val: 100%|██████████| 20/20 [00:27<00:00,  1.38s/it]

val Loss: 1.2178 Acc: 0.6505

Training complete in 33m 29s
Best val Acc: 0.652485





In [86]:
PATH = "mobilenet_weights.pth"
torch.save(model.state_dict(), PATH)