In [1]:
import os
import time
import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision.models import vit_b_16, vit_l_16
from torchvision.models import ViT_B_16_Weights
from src.model_managers.standard_model_manager import StandardModelManager

from src.dataset_loaders.download_openimages import OpenImagesLoader
from tqdm import tqdm, tqdm_notebook


from transformers import AutoImageProcessor, DetrForObjectDetection
from PIL import Image




In [2]:
# Device Configuration: prefer the GPU whenever PyTorch reports CUDA support,
# otherwise fall back to running everything on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

print(f"Device being used: {device}")

Device being used: cuda


In [3]:
### Loading Open Images Dataset:


# Data Configuration & Hyperparameters:
PERC_KEEP = 0.25  # Proportion of data from datasets to keep
BATCH_SIZE = 128  # Batch size
EPOCHS = 10  # Number of fine-tuning epochs
LEARNING_RATE = 5e-4  # Learning rate passed to the optimizer
SEED = 42  # Seed for the numpy / torch RNGs so re-runs are repeatable

# Seed before building the loader and the model. If the loader subsamples with
# numpy/torch RNGs (not visible here — unverified), the PERC_KEEP subset will be
# the same on every run. The torch seed also fixes the new head's random init.
np.random.seed(SEED)
torch.manual_seed(SEED)

oi_loader = OpenImagesLoader(batch_size=BATCH_SIZE, perc_keep=PERC_KEEP)
print(f"Number of classes: {len(oi_loader.classes)}")

# Fetch the data for every class. Classes whose data is already on disk are
# skipped (see the captured output below).
oi_loader.download_data()

# TODO(review): the training cell expects train_360 / val_360 / test_360 data
# loaders, but nothing in this cell creates them. They were meant to come from
# Fruits360Loader(batch_size=BATCH_SIZE, perc_keep=PERC_KEEP).load_data(),
# which is not imported. Build the train/val/test loaders here, from whichever
# dataset is actually used for training.


Number of classes: 64
Attempting to download Hot dog data
Skipped Hot dog, data already downloaded
Attempting to download French fries data
Skipped French fries, data already downloaded
Attempting to download Waffle data
Skipped Waffle, data already downloaded
Attempting to download Pancake data
Skipped Pancake, data already downloaded
Attempting to download Burrito data
Skipped Burrito, data already downloaded
Attempting to download Pretzel data
Skipped Pretzel, data already downloaded
Attempting to download Popcorn data
Skipped Popcorn, data already downloaded
Attempting to download Cookie data
Skipped Cookie, data already downloaded
Attempting to download Muffin data
Skipped Muffin, data already downloaded
Attempting to download Ice cream data
Skipped Ice cream, data already downloaded
Attempting to download Cake data
Skipped Cake, data already downloaded
Attempting to download Candy data
Skipped Candy, data already downloaded
Attempting to download Guacamole data
Skipped Guacamole,

2024-12-06  17:08:11 INFO Downloading 34 train images for class 'garden asparagus'
100%|██████████████████████████████████████████████████████████████████████████████████| 34/34 [00:03<00:00,  8.87it/s]
2024-12-06  17:08:16 INFO Creating 34 train annotations (pascal) for class 'garden asparagus'
100%|██████████████████████████████████████████████████████████████████████████████████| 34/34 [00:04<00:00,  7.80it/s]
2024-12-06  17:08:22 INFO Downloading 5 validation images for class 'garden asparagus'
100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00,  6.23it/s]
2024-12-06  17:08:23 INFO Creating 5 validation annotations (pascal) for class 'garden asparagus'
100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:01<00:00,  3.06it/s]
2024-12-06  17:08:28 INFO Downloading 11 test images for class 'garden asparagus'
100%|███████████████████████████████████████████████████████████████████████

Attempting to download Pumpkin data
Skipped Pumpkin, data already downloaded
Attempting to download Zucchini data
Skipped Zucchini, data already downloaded
Attempting to download Cabbage data
Skipped Cabbage, data already downloaded
Attempting to download Carrot data
Skipped Carrot, data already downloaded
Attempting to download Salad data
Skipped Salad, data already downloaded
Attempting to download Broccoli data
Skipped Broccoli, data already downloaded
Attempting to download Bell pepper data
Skipped Bell pepper, data already downloaded
Attempting to download Winter melon data
Skipped Winter melon, data already downloaded
Attempting to download Honeycomb data
Skipped Honeycomb, data already downloaded
Attempting to download Hamburger data
Skipped Hamburger, data already downloaded
Attempting to download Submarine sandwich data
Skipped Submarine sandwich, data already downloaded
Attempting to download Cheese data
Skipped Cheese, data already downloaded
Attempting to download Milk data

In [None]:
# Loading DETR Resnet-50 Model from HuggingFace:
# NOTE(review): this cell previously ended in an unclosed `img_proc(` call,
# which is a SyntaxError that stops Restart & Run All. The intended input image
# is not visible, so the sample path below is a placeholder to set.

img_proc = AutoImageProcessor.from_pretrained('facebook/detr-resnet-50')
detr_model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")


def detect_objects(image, threshold=0.9):
    """Run DETR on a single PIL image and return its post-processed detections.

    Args:
        image: a PIL image (RGB).
        threshold: minimum confidence score for a detection to be kept.

    Returns:
        A dict with "scores", "labels" and "boxes" tensors. Boxes are in
        absolute pixel coordinates of ``image``.
    """
    # return_tensors="pt" is needed so `**inputs` passes torch tensors to the model.
    inputs = img_proc(images=image, return_tensors="pt")
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        outputs = detr_model(**inputs)
    # PIL's .size is (width, height); post-processing expects (height, width).
    target_sizes = torch.tensor([image.size[::-1]])
    return img_proc.post_process_object_detection(
        outputs, threshold=threshold, target_sizes=target_sizes
    )[0]


SAMPLE_IMAGE_PATH = None  # TODO: set to an image file to sanity-check DETR on

if SAMPLE_IMAGE_PATH is not None:
    sample_image = Image.open(SAMPLE_IMAGE_PATH).convert("RGB")
    detections = detect_objects(sample_image)
    for score, label, box in zip(detections["scores"], detections["labels"], detections["boxes"]):
        box_coords = [round(v, 1) for v in box.tolist()]
        print(f"{detr_model.config.id2label[label.item()]}: {score.item():.3f} at {box_coords}")



In [None]:
# Loading vit_b_16 model with pre-trained weights on the ImageNet dataset:
vit_b = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)

# Number of output classes for the new classification head.
# NOTE(review): 141 is the Fruits360 class count, but the data cell loads Open
# Images, whose printed output reports 64 classes. Set this from whichever
# dataset the training cell actually uses (e.g. len(oi_loader.classes)).
num_classes = 141

# Replace the ImageNet head with a fresh linear layer sized for our classes.
# in_features is read from the existing head instead of hard-coding 768, so
# this still works if the backbone is swapped for e.g. vit_l_16.
vit_b.heads.head = nn.Linear(in_features=vit_b.heads.head.in_features, out_features=num_classes)

# Freeze the whole network, then unfreeze only the new head, so fine-tuning
# updates just the classification layer.
vit_b.requires_grad_(False)
vit_b.heads.head.requires_grad_(True)

# Model Training Configuration:
criterion = nn.CrossEntropyLoss()
# Give AdamW only the trainable parameters; frozen ones never receive gradients.
trainable_params = [param for param in vit_b.parameters() if param.requires_grad]
optimizer = optim.AdamW(trainable_params, lr=LEARNING_RATE)

# Wrapping the model in the StandardModelManager:
vit_b_wrapper = StandardModelManager(model=vit_b, criterion=criterion, optimizer=optimizer, device=device)


In [None]:
# Training the model: fine-tune for EPOCHS epochs using the validation loader
# alongside the training loader.
# NOTE(review): train_360 / val_360 / test_360 are not created by any visible
# cell, so on a fresh kernel this raises NameError until they are defined.
vit_b_wrapper.train(
    training_data_loader=train_360,
    validation_data_loader=val_360,
    epochs=EPOCHS,
)

# Create, save and display the learning curve recorded during training.
vit_b_wrapper.plot_learning_curve("vit_b_16")

# Evaluate the fine-tuned model on the held-out test split.
vit_b_wrapper.test(test_360)