In [1]:
import os
import time
import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision.models import vit_b_16, vit_l_16
from torchvision.models import ViT_B_16_Weights
from src.model_managers.standard_model_manager import StandardModelManager

from src.dataset_loaders.download_openimages import OpenImagesLoader
from tqdm import tqdm, tqdm_notebook


from transformers import AutoImageProcessor, DetrForObjectDetection
from PIL import Image




In [2]:
# Device configuration: use the GPU when PyTorch reports CUDA is available,
# otherwise fall back to the CPU. `device` is passed to the model manager later.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

print(f"Device being used: {device}")

Device being used: cuda


In [None]:
### Loading Open Images Dataset:

# ---- Hyperparameters & data configuration (tune these) ----
EPOCHS = 10             # number of training epochs
LEARNING_RATE = 5e-4    # optimizer learning rate
BATCH_SIZE = 16         # samples per batch handed to the loader
PERC_KEEP = 0.10        # proportion of the dataset's data to keep

# Build the Open Images loader once and report how many classes it exposes.
oi_loader = OpenImagesLoader(perc_keep=PERC_KEEP, batch_size=BATCH_SIZE)
print(f"Number of classes: {len(oi_loader.classes)}")

# One-off dataset preparation steps, intentionally left disabled here.
# Presumably only needed when the data is not yet downloaded/split on disk;
# uncomment the relevant call(s) in that case:
#   oi_loader.download_data()
#   oi_loader.split_data(keep_class_dirs=False)
#   oi_loader.split_data_reduced(keep_class_dirs=False)

# get_dataloaders() yields three loaders, unpacked as train / validation / test.
train_set, val_set, test_set = oi_loader.get_dataloaders()

print(f"Number of Batches in Training Set: {len(train_set)}")
print(f"Number of Batches in Validation Set: {len(val_set)}")
print(f"Number of Batches in Testing Set: {len(test_set)}")


Number of classes: 64
Splitting data for class Hot dog
Splitting data for class French fries
Splitting data for class Waffle
Splitting data for class Pancake
Splitting data for class Burrito
Splitting data for class Pretzel
Splitting data for class Popcorn
Splitting data for class Cookie
Splitting data for class Muffin
Splitting data for class Ice cream
Splitting data for class Cake
Splitting data for class Candy
Splitting data for class Guacamole
Splitting data for class Apple
Splitting data for class Grape
Splitting data for class Common fig
Splitting data for class Pear
Splitting data for class Strawberry
Splitting data for class Tomato
Splitting data for class Lemon
Splitting data for class Banana
Splitting data for class Orange
Splitting data for class Peach
Splitting data for class Mango
Splitting data for class Pineapple
Splitting data for class Grapefruit
Splitting data for class Pomegranate
Splitting data for class Watermelon
Splitting data for class Cantaloupe
Splitting data 

In [None]:
# Loading the DETR ResNet-50 object-detection model and its matching image
# processor from the HuggingFace Hub:
img_proc = AutoImageProcessor.from_pretrained('facebook/detr-resnet-50')
detr_model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")

# Smoke-test the processor -> model pipeline on a single image so this cell runs
# under Restart & Run All (the previous `img_proc(` call was left unfinished).
# TODO: replace the blank placeholder with a real Open Images sample.
sample_image = Image.new("RGB", (640, 480))
inputs = img_proc(images=sample_image, return_tensors="pt")

# Inference only: no autograd graph is needed for a forward pass.
with torch.no_grad():
    outputs = detr_model(**inputs)

outputs.logits.shape



In [None]:
# Loading vit_b_16 with weights pre-trained on ImageNet-1k:
vit_b = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)

# Number of target classes, derived from the Open Images loader created in the
# data-loading cell. The former hard-coded value (141) belonged to the Fruits-360
# dataset and did not match the Open Images classes this notebook trains on.
num_classes = len(oi_loader.classes)

# Replace the classification head so it emits one logit per Open Images class.
# Reading in_features from the existing head avoids hard-coding ViT-B's width (768).
vit_b.heads.head = nn.Linear(in_features=vit_b.heads.head.in_features, out_features=num_classes)

# Freeze the whole network...
for param in vit_b.parameters():
    param.requires_grad = False

# ...then unfreeze only the new head, so fine-tuning updates just that layer.
for param in vit_b.heads.head.parameters():
    param.requires_grad = True

# Model training configuration:
criterion = nn.CrossEntropyLoss()
# Hand the optimizer only the parameters that are actually being trained (the head).
trainable_params = [param for param in vit_b.parameters() if param.requires_grad]
optimizer = optim.AdamW(trainable_params, lr=LEARNING_RATE)

# NOTE(review): IMAGENET1K_V1 weights expect inputs preprocessed like
# ViT_B_16_Weights.IMAGENET1K_V1.transforms() (224x224 crop, ImageNet normalization).
# Confirm OpenImagesLoader applies an equivalent transform.

# Wrapping the model in the StandardModelManager:
vit_b_wrapper = StandardModelManager(model=vit_b, criterion=criterion, optimizer=optimizer, device=device)


In [None]:
# Training the model on the Open Images loaders built in the data-loading cell.
# (The previous train_360 / val_360 / test_360 names were never defined in this
# notebook and raised NameError on a fresh kernel.)
vit_b_wrapper.train(training_data_loader=train_set, validation_data_loader=val_set, epochs=EPOCHS)

# Creating, saving, and displaying learning curve from training:
vit_b_wrapper.plot_learning_curve("vit_b_16")

# Testing the model on the held-out test loader:
vit_b_wrapper.test(test_set)