# Training

In [None]:
# Import every dependency used by this notebook: standard library first, then third-party
# packages (kept in their original relative order), then project-local modules.
import os
import random

import numpy as np
import torch
from PIL import Image

import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
import albumentations as A
import torchvision.transforms as tra
from torchvision.transforms import functional as F
import matplotlib.pyplot as plt
import cv2

import utils
from training.datatools import COCOObjectDetectionDataset, get_box_transform
from training.engine import train_one_epoch, evaluate, validate

# Seed the global RNGs so runs are reproducible: torch's default generator drives
# random_split (when no generator is passed), DataLoader shuffling and the freshly
# initialised FastRCNNPredictor head.
# NOTE(review): whether albumentations draws from Python's / NumPy's global RNG depends on
# the installed albumentations version -- confirm.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# The dataset paths used later are relative, so show which directory they resolve against.
current_directory = os.getcwd()
print(current_directory)

## Load a COCO-pretrained Faster R-CNN and replace its box predictor

In [None]:

# Faster R-CNN (ResNet-50 FPN backbone) object detector with COCO-pretrained weights.
# Its box predictor is replaced by one sized for our own classes.
num_classes = 2  # background + grape
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour of `weights=`.

# The new head must accept the same feature width that the original classifier consumed.
roi_heads = model.roi_heads
in_features = roi_heads.box_predictor.cls_score.in_features
roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)


## Create the dataset object and data loaders

In [None]:
import torchvision.datasets as dset  # NOTE(review): unused here; kept in case later cells rely on it

# Build the COCO-format detection dataset and split it into training / validation subsets.
# Both paths are relative to the notebook's working directory.
img_path = "./images"
json_path = "./trainval_detection.json"

# Augmentations applied by the dataset. Boxes are in pascal_voc format
# ([x_min, y_min, x_max, y_max] in pixels) and are transformed together with the image.
# The matching class ids travel in the `labels` field.
# You can add your own albumentations transforms to this list.
transform_list = A.Compose(
    [
        A.HorizontalFlip(p=0.5),
        A.Rotate(limit=(-6, 6), p=0.6),
        A.RandomScale(scale_limit=0.2, p=0.5),
        # Other options that have been tried:
        # A.Resize(200, 300),
        # A.CenterCrop(100, 100),
        # A.RandomCrop(80, 80),
        # A.ShiftScaleRotate(rotate_limit=50, p=0.6),
        # A.SmallestMaxSize(max_size=1292, p=1),
        # A.VerticalFlip(p=0.5),
        # A.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ],
    bbox_params=A.BboxParams(format='pascal_voc', label_fields=['labels']),
)

trainval = COCOObjectDetectionDataset(root=img_path, json_dir=json_path, transforms=transform_list, train=True)

# Hold out a fraction of the images for validation. For the 3-image dummy dataset this gives
# the same 2 / 1 split as before, but it also works for any dataset with at least 2 images.
VAL_FRACTION = 1 / 3
SPLIT_SEED = 0  # fixed so the same images land in train / val on every run
n_images = len(trainval)
assert n_images >= 2, f"need at least 2 images to make a train/val split, found {n_images}"
n_val = max(1, round(n_images * VAL_FRACTION))
n_train = n_images - n_val
train, val = torch.utils.data.random_split(
    trainval, [n_train, n_val], generator=torch.Generator().manual_seed(SPLIT_SEED)
)
# NOTE(review): `val` is a view of the same `trainval` dataset, so validation samples also go
# through the random `transform_list` augmentations (train=True). Consider building a separate,
# non-augmented dataset instance for validation -- confirm COCOObjectDetectionDataset's
# `transforms` / `train` semantics first.

# Training and validation data loaders.
data_loader = torch.utils.data.DataLoader(
    train, batch_size=1, shuffle=True, num_workers=0, collate_fn=utils.collate_fn
)
# Built from the held-out `val` subset (previously `train`), so the validation loss reflects
# images the model is not trained on.
data_loader_validate = torch.utils.data.DataLoader(
    val, batch_size=1, shuffle=False, num_workers=0, collate_fn=utils.collate_fn
)
    

## Now let's train for 50 epochs

In [None]:
# Train the detector for `num_epochs` epochs, validating after every epoch.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)

# Optimise only the parameters that are not frozen.
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.009, momentum=0.9, weight_decay=0.0006)

# Optional learning-rate scheduler: StepLR multiplies the LR by `gamma` every `step_size` epochs.
# gamma=1 keeps the LR constant (no scheduling). StepLR's own default gamma is 0.1, so the
# explicit gamma=1 is what disables the decay.
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=1)

# Train for 50 epochs, saving after the best validation epoch.
num_epochs = 50

# Start from +inf rather than an arbitrary large number, so any finite first validation loss
# can count as an improvement (presumably `validate` compares with `<` -- confirm).
best_val_loss = float('inf')
for epoch in range(num_epochs):
    model.train()
    # Train for one epoch, logging every `print_freq` (= 2) iterations.
    train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=2)
    # Update the learning rate (a no-op while gamma == 1).
    lr_scheduler.step()
    # Evaluate on the validation loader. `validate` receives and returns the best loss seen so far.
    # NOTE(review): checkpointing on improvement is assumed to happen inside
    # training.engine.validate -- confirm.
    best_val_loss = validate(model, data_loader_validate, device=device, best_val_loss=best_val_loss)


print('Trained for '+str(num_epochs) + ' epochs')