In [1]:
# install dataset
import kagglehub
import os
import shutil
import torch
import matplotlib.pyplot as plt
from tqdm import tqdm

import torchvision
from PIL import Image
from torchvision.models.vision_transformer import *
from torchvision.models.swin_transformer import *
from torchvision.models.resnet import *

In [None]:
kagglehub.login()

In [None]:
path = kagglehub.dataset_download("harishkumardatalab/food-image-classification-dataset")
if not os.path.exists("./data"):
    os.makedirs("./data", exist_ok=True)
    shutil.move(path, "./data")

Maximum number of images in a class is 1500. So let us consider `Taco` and `Baked Potato` classes

In [None]:
for dir in os.scandir("./data/1/Food Classification dataset"):
    if dir.name not in ["Taco", "Baked Potato"]:
        shutil.rmtree(dir.path)

In [2]:
# make a class for dataset

ID_TO_LABEL = {
    0: "Baked Potato",
    1: "Taco"
}


class Dataset(torchvision.datasets.VisionDataset):
    def __init__(self, root, transform=None):
        super().__init__(transform)
        self.root = root
        self.transform = transform
        self.samples = []
        for c_id, c in enumerate(["Baked Potato", "Taco"]):
            c_path = os.path.join(self.root, c)
            for file in os.scandir(c_path):
                self.samples.append((file.path, c_id))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        path, label = self.samples[index]
        image = Image.open(path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, label


In [3]:
root = os.path.join(os.getcwd(), "data", "1", "Food Classification dataset")

transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(
            (0.485, 0.456, 0.406),
            (0.229, 0.224, 0.225)
        ) # ImageNet stats
])

dataset = Dataset(root, transform=transforms)

In [4]:
# calculate device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [5]:
# load pre-trained ViT, SWin and ResNet models
vit_model = vit_b_16(pretrained=True)
vit_model.heads.head = torch.nn.Linear(in_features=vit_model.heads.head.in_features, out_features=2)
swin_model = swin_s(pretrained=True)
swin_model.head = torch.nn.Linear(in_features=swin_model.head.in_features, out_features=2)
resnet_model = resnet152(pretrained=True)
resnet_model.fc = torch.nn.Linear(in_features=resnet_model.fc.in_features, out_features=2)

models = [vit_model]



In [6]:
# train, validation and test split
train_size = int(0.7 * len(dataset))
val_size = int(0.15 * len(dataset))
test_size = int(0.15 * len(dataset))

train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, val_size, test_size])

for model in models:

    model = model.to(device)

    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=64, shuffle=True, pin_memory=True)
    val_loader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=64, shuffle=True, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=64, shuffle=True, pin_memory=True)

    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(model.parameters(), lr=7e-4, weight_decay=0.001)

    loss = torch.nn.CrossEntropyLoss()
    scaler = torch.amp.GradScaler("cuda")


    # training loop
    num_epochs = 100

    best_loss = float("inf")
    wait = 0
    patience = 5

    for epoch in range(num_epochs):
        print(epoch)
        model.train()
        running_loss = 0.0
        running_correct = 0
        total = 0

        for batch in tqdm(train_loader):
            optimizer.zero_grad()

            images, labels = batch
            images = images.to(device)
            labels = labels.to(device)

            with torch.amp.autocast('cuda'):
                outputs = model(images)
                l = loss(outputs, labels)
                # l.backward()
            scaler.scale(l).backward()
            scaler.step(optimizer)
            scaler.update()

            # if grad mean is less than 1e-4, then something is wrong
            grad_mean = sum(p.grad.abs().mean().item() for p in model.parameters() if p.grad is not None)

            if grad_mean < 1e-4:
                print("Grad mean too low!")

            # optimizer.step()

            running_loss += l.item() * images.size(0)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            running_correct += (predicted == labels).sum().item()

        epoch_loss = running_loss / total
        epoch_acc = running_correct / total

        print(f"Train on {epoch + 1}, Loss: {epoch_loss}, Accuracy: {epoch_acc}")

        # validation
        model.eval()

        val_loss = 0
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for batch in tqdm(val_loader):
                images, labels = batch
                images = images.to(device)
                labels = labels.to(device)

                # with torch.amp.autocast('cuda'):
                outputs = model(images)

                _, predicted = torch.max(outputs.data, 1)
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()
                val_loss += loss(outputs, labels).item() * images.size(0)

        val_epoch_loss = val_loss / val_total
        val_acc = val_correct / val_total

        print(f"Validation on {epoch + 1}, Accuracy: {val_acc}, Loss: {val_epoch_loss}")
        if val_epoch_loss < best_loss:
            best_loss = val_epoch_loss
            wait = 0
        else:
            wait += 1
            # if wait >= patience:
            #     print("Early stopping triggered!")
            #     break

    del model
    torch.cuda.empty_cache()

0


100%|██████████| 33/33 [00:22<00:00,  1.48it/s]


Train on 1, Loss: 0.978432362874349, Accuracy: 0.5157142857142857


100%|██████████| 8/8 [00:03<00:00,  2.45it/s]


Validation on 1, Accuracy: 0.5, Loss: 0.6968730923864577
1


100%|██████████| 33/33 [00:21<00:00,  1.51it/s]


Train on 2, Loss: 0.7089546128681727, Accuracy: 0.4938095238095238


100%|██████████| 8/8 [00:03<00:00,  2.42it/s]


Validation on 2, Accuracy: 0.49777777777777776, Loss: 0.6971270434061686
2


100%|██████████| 33/33 [00:21<00:00,  1.52it/s]


Train on 3, Loss: 0.7121211171150208, Accuracy: 0.4895238095238095


100%|██████████| 8/8 [00:03<00:00,  2.44it/s]


Validation on 3, Accuracy: 0.5, Loss: 0.6938474729326036
3


100%|██████████| 33/33 [00:22<00:00,  1.47it/s]


Train on 4, Loss: 0.7184256411734081, Accuracy: 0.5114285714285715


100%|██████████| 8/8 [00:03<00:00,  2.40it/s]


Validation on 4, Accuracy: 0.5, Loss: 0.7140969758563571
4


 24%|██▍       | 8/33 [00:05<00:16,  1.48it/s]


KeyboardInterrupt: 

In [8]:
del model
torch.cuda.empty_cache()