In [2]:
import torch
from torch.utils.data import DataLoader

from data_loader import OpenImagesDataset
from model_transformations import Transformations
from model_utils import plot_tensor
from params import BATCH_SIZE, DEVICE

In [3]:
from torch import nn
from torchvision import models
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

class DogDetectorModel(nn.Module):
    """Faster R-CNN (ResNet-50 FPN) wrapper with a two-class box predictor.

    The wrapped network is torchvision's COCO-pretrained detector whose
    box-predictor head is replaced by a fresh ``FastRCNNPredictor`` sized for
    two classes (dog and background).

    CITATION: https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html

    Args:
        modelPath: Optional path to a checkpoint. When given, the file is read
            with ``torch.load`` and applied with ``load_state_dict``, so it is
            expected to contain a ``state_dict()`` saved from an instance of
            this wrapper class. ``None`` (the default) keeps the COCO weights
            together with the newly created, untrained head.
    """

    def __init__(self, modelPath=None):
        super().__init__()

        # We only have 2 classes (dog and background)
        numClasses = 2

        # Load a model pre-trained on COCO
        self.model = models.detection.fasterrcnn_resnet50_fpn(pretrained=True)

        # Number of input features feeding the existing classification head
        inFeatures = self.model.roi_heads.box_predictor.cls_score.in_features

        # Swap the head for one with the correct number of output classes
        self.model.roi_heads.box_predictor = FastRCNNPredictor(inFeatures, numClasses)

        # Previously modelPath was accepted but ignored; restore saved weights
        # when a checkpoint is supplied. Loading onto the CPU keeps this
        # independent of the device the checkpoint was written from.
        if modelPath is not None:
            stateDict = torch.load(modelPath, map_location="cpu")
            self.load_state_dict(stateDict)

    def forward(self, modelInput, targets=None):
        """Run the wrapped detector and return its output unchanged.

        Args:
            modelInput: Images passed straight through to the torchvision model.
            targets: Optional ground-truth annotations passed as the detector's
                second argument. Per the torchvision Faster R-CNN documentation
                they are required in training mode; leave as ``None`` for
                inference.

        Returns:
            Whatever the wrapped torchvision model returns for these inputs.
        """
        return self.model(modelInput, targets)

In [4]:
# Build the detector with default arguments (no modelPath supplied).
model = DogDetectorModel()

In [5]:
# Build the training split of the Open Images dataset.
trainingData = OpenImagesDataset(
    rootDirectory='open-images-v6',
    transform=Transformations,
    dataType='train',
)

# Wrap the dataset in a DataLoader: one sample per batch, dataset order, 8 workers.
# NOTE(review): BATCH_SIZE is imported from params but not used here, and
# shuffle=False is unusual for a training loader — confirm both are intentional.
trainDataLoader = DataLoader(
    dataset=trainingData,
    batch_size=1,
    num_workers=8,
    shuffle=False,
)

In [7]:
# Move the model onto DEVICE (imported from params) before the optimizer is created.
model = model.to(DEVICE)

In [None]:
# Optimise only the parameters that still require gradients.
params = [p for p in model.parameters() if p.requires_grad]

# SGD hyperparameters, named so they can be tuned in one place.
LEARNING_RATE = 0.005
MOMENTUM = 0.9
WEIGHT_DECAY = 0.0005

optimizer = torch.optim.SGD(
    params,
    lr=LEARNING_RATE,
    momentum=MOMENTUM,
    weight_decay=WEIGHT_DECAY,
)

In [None]:
# Sanity check: fetch the first training sample and draw its bounding box.
# Reuses the `trainingData` dataset built earlier in the notebook instead of
# constructing an identical one again, and indexes it rather than calling the
# __getitem__ dunder directly.
image, boxes = trainingData[0]

# NOTE(review): reshape(1, -1) flattens `boxes` into a single row — presumably
# plot_tensor expects one row per box; confirm against model_utils.
plot_tensor(image, boxes.reshape(1, -1))

In [None]:
# Fetch the first training sample (index 0) and unpack it into image and boxes.
image, boxes = trainingData[0]

In [None]:
# Draw the sample image together with its boxes reshaped to a single row.
# NOTE(review): assumes `boxes` describes exactly one box — TODO confirm.
plot_tensor(image, boxes.reshape(1, -1))

In [None]:
# Display the boxes tensor reshaped to a single row (the same form handed to plot_tensor above).
boxes.reshape(1,-1)