In [1]:
# standard library
import os
import time
from itertools import islice

# third-party: general utilities
# NOTE(review): this binds `plt` to the top-level matplotlib package, not to
# matplotlib.pyplot — switch to `import matplotlib.pyplot as plt` before plotting.
import matplotlib as plt
import numpy as np
from tqdm import tqdm, tqdm_notebook

# third-party: torch
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

# third-party: transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import BertTokenizer, BertForQuestionAnswering
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# local: model managers and dataset loaders
from src.dataset_loaders.fruits360 import Fruits360Loader
from src.model_managers.standard_model_manager import FRCNNModelManager
from src.model_managers.standard_model_manager import StandardModelManager

# set device: prefer CUDA, then Apple's Metal (MPS) backend, otherwise fall back to CPU
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

print(f"Device being used: {device}")

Device being used: cpu


In [2]:
# get data
# One seed drives everything stochastic in this notebook. It is passed to the loader
# (presumably for its split/shuffling — confirm in Fruits360Loader) and also seeds the
# global torch/numpy RNGs. Those RNGs govern the randomly initialised FastRCNNPredictor
# head built in the next cell and any unseeded randomness during training.
RANDOM_SEED = 101
BATCH_SIZE = 128
PERC_KEEP = 1.0  # assumed: fraction of the dataset to keep (1.0 = all) — TODO confirm

torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

fl = Fruits360Loader(random_seed=RANDOM_SEED,
                     batch_size=BATCH_SIZE,
                     perc_keep=PERC_KEEP)
train, val, test = fl.load_data()

In [3]:
# get/create model
def get_model(num_classes, model_fn=fasterrcnn_resnet50_fpn_v2, weights="DEFAULT"):
    """Build a pretrained torchvision Faster R-CNN whose box predictor is resized to `num_classes`.

    Args:
        num_classes: number of outputs of the new ``FastRCNNPredictor``. torchvision's
            detection models count the background as a class, so this should be
            (number of object classes + 1).
        model_fn: torchvision Faster R-CNN builder to use. Options include
            ``fasterrcnn_resnet50_fpn``, ``fasterrcnn_resnet50_fpn_v2`` (default),
            ``fasterrcnn_mobilenet_v3_large_fpn`` and
            ``fasterrcnn_mobilenet_v3_large_320_fpn``. All are in
            ``torchvision.models.detection``.
        weights: forwarded as ``model_fn(weights=...)``. ``"DEFAULT"`` selects the
            builder's default pretrained weights, which are downloaded on first use.

    Returns:
        The detection model with ``roi_heads.box_predictor`` replaced.
    """
    model = model_fn(weights=weights)

    # Replace the pretrained classification/box-regression head with one sized for
    # our label space, reusing the feature width the original head consumed.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    return model


# NOTE(review): num_classes includes background for torchvision detectors —
# confirm 138 == number of Fruits360 classes in the loaded data + 1.
model = get_model(num_classes=138)

In [4]:
# configure training (training itself runs in a later cell)
lr = 0.001
epochs = 10
# NOTE(review): in training mode, torchvision's FasterRCNN computes its own
# classification and box-regression losses (it returns a loss dict). Confirm
# whether FRCNNModelManager actually uses this criterion.
criterion = nn.CrossEntropyLoss()
# Hand the optimizer only parameters that can be updated. With pretrained weights,
# torchvision's detection builders freeze the early backbone layers by default
# (`trainable_backbone_layers`), and the torchvision detection reference filters
# on requires_grad in the same way.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.AdamW(trainable_params, lr=lr)
smm = FRCNNModelManager(model=model,
                        criterion=criterion,
                        optimizer=optimizer,
                        device=device)

In [5]:
# Sanity check: print the sizes of the first two training batches only.
# islice stops the DataLoader after two batches, replacing the manual
# `idx < 2` guard plus `break`.
from itertools import islice  # no-op if already imported in the top cell

for idx, (data, target) in enumerate(islice(train, 2)):
    print(f"idx: {idx}\nlen data: {len(data)}\n len target: {len(target)}")

idx: 0
len data: 128
 len target: 128
idx: 1
len data: 128
 len target: 128


In [None]:
# Train the detector. The validation loader is passed so the manager can evaluate on
# held-out data; see FRCNNModelManager.train for when it does so.
smm.train(
    training_data_loader=train,
    validation_data_loader=val,
    epochs=epochs,
)