In [None]:
# utils
from src.model_managers.standard_model_manager import StandardModelManager
from src.model_managers.standard_model_manager import FRCNNModelManager
from tqdm import tqdm, tqdm_notebook
import matplotlib as plt
import pandas as pd
import numpy as np
import time
import os

# torch
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2
from torchvision.models.detection import fasterrcnn_resnet50_fpn
import torch.optim as optim
import torch.nn as nn
import torchvision
import torch

# transformers
from transformers import BertTokenizer, BertForQuestionAnswering
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# load data
from src.dataset_loaders.fruits360 import Fruits360Loader
from src.dataset_loaders.download_openimages import OpenImagesLoader

# Choose the compute device: CUDA is preferred, then Apple's MPS backend,
# and the CPU is the fallback when neither accelerator reports as available.
device = (
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)

print(f"Device being used: {device}")

Device being used: cpu


In [4]:
# get data
# Both loaders are built with the same seed, batch size and keep fraction,
# so those settings live in one place.
shared_loader_args = {"random_seed": 101, "batch_size": 128, "perc_keep": 1.0}

# Fruits360: load_data() yields the train / validation / test splits.
fl = Fruits360Loader(**shared_loader_args)
train_fl, val_fl, test_fl = fl.load_data()

# OpenImages: capped at 500 images per class, annotations in 'pascal' format.
oil = OpenImagesLoader(num_images_per_class=500,
                       annotation_format='pascal',
                       **shared_loader_args)
# The two calls below were left disabled by the author; presumably they are
# only needed before the data exists locally — confirm before re-enabling.
# oil.download_data()
# oil.split_data()
train, val, test = oil.get_datasets()

In [5]:
# get/create model
def get_model(num_classes, builder=None, weights="DEFAULT"):
    """Create a torchvision Faster R-CNN with a box head sized for ``num_classes``.

    The detector is built by ``builder`` and its box predictor is replaced by a
    fresh ``FastRCNNPredictor`` whose input width matches the predictor that
    came with the model.

    Args:
        num_classes: Number of outputs of the new box predictor. torchvision's
            detection models count the background as one of the classes, so
            this is normally "object classes + 1".
        builder: Callable accepting a ``weights`` keyword and returning a
            Faster R-CNN model that exposes
            ``roi_heads.box_predictor.cls_score``. Candidates in torchvision:
            fasterrcnn_resnet50_fpn, fasterrcnn_resnet50_fpn_v2,
            fasterrcnn_mobilenet_v3_large_fpn,
            fasterrcnn_mobilenet_v3_large_320_fpn.
            ``None`` (default) selects ``fasterrcnn_resnet50_fpn_v2``.
        weights: Forwarded unchanged to ``builder``; the default "DEFAULT"
            keeps the original behaviour. Pass ``None`` to skip loading
            pretrained detector weights.

    Returns:
        The model with its box predictor replaced.
    """
    # Resolved at call time so that defining the function does not depend on
    # the default builder already being imported.
    if builder is None:
        builder = fasterrcnn_resnet50_fpn_v2

    model = builder(weights=weights)

    # Read the input width from the existing classification layer so the new
    # head fits whichever backbone the builder produced.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    return model
        
# Width of the detector's classification head.
# NOTE(review): torchvision detection heads count the background as a class —
# confirm that 138 already includes it for the dataset being trained on.
NUM_CLASSES = 138
model = get_model(num_classes=NUM_CLASSES)

In [6]:
# Training configuration and manager wiring; nothing is trained in this cell.
epochs = 10
lr = 0.001

# AdamW over every parameter of the detector.
optimizer = optim.AdamW(model.parameters(), lr=lr)

# NOTE(review): torchvision detection models return their own loss dict when
# called with targets in training mode — confirm FRCNNModelManager really
# uses this criterion.
criterion = nn.CrossEntropyLoss()

smm = FRCNNModelManager(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    device=device,
)

In [7]:
# print('fruits 360')
# for idx, (data, target) in enumerate(train_fl):
#     if idx < 2:
#         print(f"{idx} {len(data)} {len(target)}")
#         if idx == 1:
#             break

# print('\nopenimages')
# for idx, (img, target) in enumerate(train):
#     if idx < 2:
#         print(f"{idx} {len(img)} {len(target)}")
#         if idx == 1:
#             break
# print(f"\nopenimages:\n{len(train)}, {len(val)}, {len(test)}\nfruit360:\n{len(train_fl)}, {len(val_fl)}, {len(test_fl)}")

# targets = []
# for idx, (data, target) in enumerate(train):
#     # Train Batch
#     targ = {"labels": target}
#     targ["boxes"] = smm.generate_hardcoded_boxes(targ["labels"])
#     targets.append(targ)
#     #roi = smm.prepare_targets(targ["labels"])
#     #targets.append(roi)
#     if idx == 1:
#         break
# targets

In [None]:
# Run training for `epochs` epochs through the model manager.
# NOTE(review): training batches come from the Fruits360 loader (`train_fl`)
# while validation uses the OpenImages split (`val`). Confirm this
# cross-dataset pairing is intended and that `has_box=False` suits both;
# otherwise validate on `val_fl`.
smm.train(
    epochs=epochs,
    has_box=False,
    training_data_loader=train_fl,
    validation_data_loader=val,
)

  0%|          | 0/10 [00:00<?, ?it/s]