In [1]:
import torch
import numpy as np
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
import PIL.Image as Image
import transforms as T
from engine import train_one_epoch, evaluate
import utils
from MathExpressionDataset import MEdataset
import os


In [2]:

def get_transform(train):
    """Build the transform pipeline applied to every dataset sample.

    Args:
        train: When truthy, a random horizontal flip with probability 0.5 is
            appended after the tensor conversion; otherwise only the tensor
            conversion is applied.

    Returns:
        A ``T.Compose`` wrapping the selected transforms in application order.
    """
    # NOTE(review): a horizontal flip mirrors asymmetric symbols (e.g. "(" vs
    # ")", "<" vs ">") while the class label stays the same -- confirm this
    # augmentation is actually wanted for math-symbol detection.
    pipeline = [T.ToTensor()]
    if train:
        pipeline.append(T.RandomHorizontalFlip(0.5))
    return T.Compose(pipeline)

In [3]:
# number of classes the new head must predict:
# 108 LaTeX symbols + the background/nothing class
num_classes = 109

# load a Faster R-CNN detector pre-trained on COCO
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)

# the replacement head has to accept the same feature width as the
# classification layer of the pre-trained box predictor
in_features = model.roi_heads.box_predictor.cls_score.in_features

# swap the pre-trained head for a fresh one sized for num_classes
model.roi_heads.box_predictor = FastRCNNPredictor(
    in_channels=in_features,
    num_classes=num_classes,
)

In [4]:
# Root folder of the dataset archive; it must contain `train` and `test`
# sub-directories. The default is the original author's local checkout.
# Set the ME_DATA_ROOT environment variable to point at another copy instead
# of editing the notebook.
dataRoot = os.environ.get(
    'ME_DATA_ROOT',
    r'C:\Users\maxwe\Desktop\My Documents\MathExprSolverMx\MathExprSolverMx\AidaCalculusHandWrittenMathDataset\archive',
)
trainDir = os.path.join(dataRoot, 'train')
testDir = os.path.join(dataRoot, 'test')


In [5]:
# pick the compute device: the CUDA GPU when one is available, else the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# build the two datasets from their own directories: the training set gets the
# training transforms, the test set gets the evaluation transforms
dataset = MEdataset(trainDir, get_transform(train=True))
dataset_test = MEdataset(testDir, get_transform(train=False))

# draw a random ordering for each dataset (train first, then test)
indicesTrain = torch.randperm(len(dataset)).tolist()
indicesTest = torch.randperm(len(dataset_test)).tolist()

# wrap each dataset in a Subset over ALL of its indices in that random order;
# nothing is held out here -- train and test already come from separate folders
dataset = torch.utils.data.Subset(dataset, list(indicesTrain))
dataset_test = torch.utils.data.Subset(dataset_test, list(indicesTest))


In [6]:
# training loader: batches of 2, reshuffled by the loader, 4 worker processes
data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=2,
    shuffle=True,
    num_workers=4,
    collate_fn=utils.collate_fn,
)

# evaluation loader: one sample per batch, served in the dataset's own order
data_loader_test = torch.utils.data.DataLoader(
    dataset_test,
    batch_size=1,
    shuffle=False,
    num_workers=4,
    collate_fn=utils.collate_fn,
)

In [7]:
# Move the model onto the device selected earlier (GPU when available, else CPU).
# The result is not assigned; because this is the last expression of the cell
# it is echoed, which is what produces the module summary shown as output.
model.to(device)

FasterRCNN(
  (transform): GeneralizedRCNNTransform(
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
      Resize(min_size=(800,), max_size=1333, mode='bilinear')
  )
  (backbone): BackboneWithFPN(
    (body): IntermediateLayerGetter(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): FrozenBatchNorm2d(64, eps=0.0)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(64, eps=0.0)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(64, eps=0.0)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(256, eps=0.0)
          (relu): ReLU(

In [8]:
# hand the optimizer only the parameters that still require gradients
params = list(filter(lambda p: p.requires_grad, model.parameters()))

# SGD with momentum and weight decay
optimizer = torch.optim.SGD(
    params=params,
    lr=0.005,
    momentum=0.9,
    weight_decay=0.0005,
)

# step schedule: multiply the learning rate by 0.1 every 3 scheduler steps
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer=optimizer,
    step_size=3,
    gamma=0.1,
)

In [None]:
# Number of passes over the training loader. It is currently set to a single
# epoch; raise it (e.g. to 10) for a full training run.
num_epochs = 1

for epoch in range(num_epochs):
    # train for one epoch; print_freq=10 asks for a progress line every 10 iterations
    train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10)
    # advance the learning-rate schedule once per epoch
    lr_scheduler.step()
    # evaluate on the test dataset after every epoch
    evaluate(model, data_loader_test, device=device)

Epoch: [0]  [  0/500]  eta: 0:34:03  lr: 0.000015  loss: 7.1823 (7.1823)  loss_classifier: 4.6797 (4.6797)  loss_box_reg: 0.9286 (0.9286)  loss_objectness: 1.4138 (1.4138)  loss_rpn_box_reg: 0.1603 (0.1603)  time: 4.0879  data: 2.4531  max mem: 1911
Epoch: [0]  [ 10/500]  eta: 0:06:12  lr: 0.000115  loss: 7.1823 (7.3501)  loss_classifier: 4.5375 (4.4591)  loss_box_reg: 0.9360 (0.9436)  loss_objectness: 1.4138 (1.4139)  loss_rpn_box_reg: 0.3599 (0.5335)  time: 0.7610  data: 0.2247  max mem: 2835
Epoch: [0]  [ 20/500]  eta: 0:04:55  lr: 0.000215  loss: 5.1327 (5.8713)  loss_classifier: 3.9219 (3.7098)  loss_box_reg: 0.9349 (0.9153)  loss_objectness: 0.1264 (0.7840)  loss_rpn_box_reg: 0.3490 (0.4622)  time: 0.4429  data: 0.0022  max mem: 2835
Epoch: [0]  [ 30/500]  eta: 0:04:23  lr: 0.000315  loss: 3.4746 (5.0723)  loss_classifier: 1.8772 (3.1038)  loss_box_reg: 0.9200 (0.9263)  loss_objectness: 0.0865 (0.5611)  loss_rpn_box_reg: 0.4093 (0.4811)  time: 0.4488  data: 0.0013  max mem: 2846


In [10]:

# write only the model's state_dict to model1.pt in the current working
# directory
savePath = os.path.join(os.getcwd(), 'model1.pt')
torch.save(obj=model.state_dict(), f=savePath)