In [None]:
import torch
# --- Run configuration -------------------------------------------------------
BATCH_SIZE = 4  # increase / decrease according to GPU memory
RESIZE_TO = 512  # resize the image for training and transforms
NUM_EPOCHS = 10  # number of epochs to train for
# Use the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Dataset locations. These are relative paths, so they depend on the current
# working directory of the kernel.
image_dir = 'drive/MyDrive/DetectionDataset/JPEGImages'
annotation_dir = 'drive/MyDrive/DetectionDataset/Annotations'

# Class names as they appear in the annotation files; the position of a name
# in this list is the integer label used for it.
classes = [
    'background', 'fore'
]

# Derived from `classes` so the class count cannot drift from the class list.
num_classes = len(classes)

In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive (Colab-only helper).
# NOTE(review): `image_dir` / `annotation_dir` are relative ('drive/MyDrive/...'),
# so they presumably rely on the working directory being /content -- confirm.
drive.mount('/content/drive')

In [None]:
import torch
import cv2
import numpy as np
import os
import glob as glob
from xml.etree import ElementTree as et
#from config import CLASSES, RESIZE_TO, TRAIN_DIR, VALID_DIR, BATCH_SIZE
from torch.utils.data import Dataset, DataLoader
#from utils import collate_fn, get_train_transform, get_valid_transform
from PIL import Image

In [None]:
from torchvision.transforms import Compose, Resize, ToTensor

In [None]:
# Install albumentations pinned to release 0.4.6; the import cell that follows
# (albumentations, ToTensorV2) needs this to have run first in a fresh runtime.
!pip install albumentations==0.4.6

In [None]:
import albumentations as A
from albumentations.pytorch import ToTensorV2

In [None]:
def get_transform():
    """
    Build the albumentations pipeline applied to each (image, boxes, labels) sample.

    The pipeline randomly flips the sample with probability 0.5 and then
    converts the image to a torch tensor with ``ToTensorV2``.

    Bounding boxes are expected in ``pascal_voc`` format
    (``[xmin, ymin, xmax, ymax]`` in pixels) and the per-box class ids must be
    passed under the ``labels`` keyword when the pipeline is called.

    Returns:
        An ``A.Compose`` object; call it as
        ``pipeline(image=..., bboxes=..., labels=...)``.
    """
    return A.Compose(
        [
            # `p` has to be given by keyword: the first positional parameter
            # of an albumentations transform is `always_apply`, so the former
            # `A.Flip(0.5)` set always_apply to a truthy value and flipped
            # every single sample instead of roughly half of them.
            A.Flip(p=0.5),
            ToTensorV2(p=1.0),
        ],
        bbox_params={
            'format': 'pascal_voc',
            'label_fields': ['labels']
        })

In [None]:
def collate_fn(batch):
    """
    Regroup a batch of per-sample tuples into per-field tuples.

    Detection samples carry a different number of boxes per image, so the
    fields cannot be stacked into a single tensor; instead the samples are
    transposed, e.g. ``[(img0, tgt0), (img1, tgt1)]`` becomes
    ``((img0, img1), (tgt0, tgt1))``.
    """
    per_field = zip(*batch)
    return tuple(per_field)

In [None]:
from matplotlib import cm

In [None]:
class AcneDetectDataset(Dataset):
    """
    Object-detection dataset pairing ``*.jpg`` images with per-image XML annotations.

    For the image ``<name>.jpg`` found in ``dir_path_image`` the annotation is
    read from ``<dir_path_ann>/<name>.xml``. Every ``object`` element of the
    XML contributes one box (``bndbox/xmin|ymin|xmax|ymax``) and one label
    (the index of ``name`` inside ``classes``).

    Args:
        dir_path_image: directory that contains the ``.jpg`` images.
        dir_path_ann: directory that contains the ``.xml`` annotation files.
        width: width every image (and its boxes) is resized to.
        height: height every image (and its boxes) is resized to.
        classes: list of class names; the list index is the integer label.
        transforms: optional callable invoked as
            ``transforms(image=..., bboxes=..., labels=...)`` and returning a
            mapping with the keys ``image``, ``bboxes`` and ``labels``.
            When ``None`` the pipeline returned by ``get_transform()`` is used.
    """

    def __init__(self, dir_path_image, dir_path_ann, width, height, classes, transforms=None):
        self.transforms = transforms
        self.dir_path_image = dir_path_image
        self.dir_path_ann = dir_path_ann
        self.height = height
        self.width = width
        self.classes = classes

        # glob returns files in arbitrary order; sorting keeps the
        # index -> image mapping stable between runs.
        self.image_paths = sorted(glob.glob(f"{self.dir_path_image}/*.jpg"))
        self.annotation_paths = sorted(glob.glob(f"{self.dir_path_ann}/*.xml"))
        # Misspelled attribute name kept as an alias for backward compatibility.
        self.annotayions = self.annotation_paths

    def _read_annotation(self, annotation_path, image_width, image_height):
        """Parse one XML file into (boxes, labels), boxes rescaled to (self.width, self.height)."""
        root = et.parse(annotation_path).getroot()

        boxes = []
        labels = []
        for member in root.findall('object'):
            # map the object name to its index in `classes`
            labels.append(self.classes.index(member.find('name').text))

            bndbox = member.find('bndbox')
            # float() accepts both "12" and "12.0" style coordinates
            xmin = float(bndbox.find('xmin').text)
            xmax = float(bndbox.find('xmax').text)
            ymin = float(bndbox.find('ymin').text)
            ymax = float(bndbox.find('ymax').text)

            # rescale from the original image size to the resized image
            boxes.append([
                (xmin / image_width) * self.width,
                (ymin / image_height) * self.height,
                (xmax / image_width) * self.width,
                (ymax / image_height) * self.height,
            ])
        return boxes, labels

    def __getitem__(self, idx):
        """
        Load sample ``idx``.

        Returns:
            ``(image, target)`` where ``image`` is whatever the transform
            pipeline returns under ``'image'`` and ``target`` is a dict with
            ``boxes`` (float32, shape ``(N, 4)``), ``labels`` (int64),
            ``area`` (float32), ``iscrowd`` (int64 zeros) and ``image_id``.

        Raises:
            OSError: if the image file cannot be read by ``cv2.imread``.
            ValueError: if an annotated class name is not in ``classes``.
        """
        image_path = self.image_paths[idx]

        # read the image; cv2.imread signals failure by returning None
        image = cv2.imread(image_path)
        if image is None:
            raise OSError(f"cv2.imread failed to load image: {image_path}")
        # original size, needed to rescale the annotated boxes
        image_height, image_width = image.shape[:2]

        image_target = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)
        image_target = cv2.resize(image_target, (self.width, self.height))
        image_target /= 255.0

        # the annotation shares the image's base name, with an .xml extension
        stem = os.path.splitext(os.path.basename(image_path))[0]
        annotation_path = os.path.join(self.dir_path_ann, stem + '.xml')
        boxes, labels = self._read_annotation(annotation_path, image_width, image_height)

        # Honour the `transforms` argument; fall back to the notebook default.
        transform = self.transforms if self.transforms is not None else get_transform()
        # Plain Python lists are handed to the transform (not torch tensors).
        sample = transform(image=image_target, bboxes=boxes, labels=labels)
        image_target = sample['image']

        # Build the target from the transform OUTPUT so boxes and labels stay
        # aligned. reshape(-1, 4) keeps the (N, 4) layout even when N == 0,
        # so images without any object no longer fail on the column indexing.
        boxes = torch.as_tensor(sample['bboxes'], dtype=torch.float32).reshape(-1, 4)
        labels = torch.as_tensor(sample['labels'], dtype=torch.int64)

        target = {
            "boxes": boxes,
            "labels": labels,
            # area of the bounding boxes
            "area": (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]),
            # no crowd instances
            "iscrowd": torch.zeros((boxes.shape[0],), dtype=torch.int64),
            "image_id": torch.tensor([idx]),
        }
        return image_target, target

    def __len__(self):
        """Number of ``.jpg`` images found in ``dir_path_image``."""
        return len(self.image_paths)

In [None]:
from torch.utils.data import DataLoader

In [None]:
# Build the training dataset and its loader. The target image size comes from
# the RESIZE_TO constant of the config cell instead of a repeated literal.
dataset = AcneDetectDataset(image_dir, annotation_dir, RESIZE_TO, RESIZE_TO, classes)
# NOTE(review): the config cell also defines BATCH_SIZE (4) which is not used
# here; the value 8 is kept as-is -- confirm which one is intended.
batch_size = 8
# collate_fn keeps images/targets as tuples because the number of boxes
# differs from image to image.
train_loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
    collate_fn=collate_fn,
)

In [None]:
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
def create_model(num_classes, pretrained=False):
    """
    Create a Faster R-CNN (ResNet-50 FPN) detector with a box head for ``num_classes``.

    Args:
        num_classes: number of output classes of the new box predictor
            (the notebook's class list includes 'background').
        pretrained: forwarded unchanged to torchvision's ``pretrained``
            argument. The default ``False`` is exactly what this function
            always did, even though its old comment spoke of a "pre-trained
            model"; pass ``True`` to request torchvision's pre-trained
            detection weights.
            NOTE(review): whether the backbone is still initialised from
            pre-trained weights when this is False depends on the installed
            torchvision's defaults -- confirm.

    Returns:
        The torchvision detection model with its box predictor replaced.
    """
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=pretrained)

    # get the number of input features of the existing classification head
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    # define a new head for the detector with the required number of classes
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    return model

In [None]:
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import time

In [None]:
def train(train_data_loader, model, optimizer):
    """
    Run one training epoch over ``train_data_loader``.

    Args:
        train_data_loader: iterable yielding ``(images, targets)`` batches,
            where ``images`` is a sequence of tensors and ``targets`` a
            sequence of dicts of tensors.
        model: detection model that returns a dict of losses when called as
            ``model(images, targets)``.
        optimizer: optimizer updating ``model``'s parameters.

    Returns:
        The mean of the per-batch total loss over the epoch
        (``nan`` if the loader yielded no batch).
    """
    print('Training')

    # torchvision detection models return the loss dict only in training
    # mode; without this call the function breaks when it is run again after
    # the model has been switched to eval() for inference.
    model.train()

    # initialize tqdm progress bar
    prog_bar = tqdm(train_data_loader, total=len(train_data_loader))

    running_loss = 0.0
    num_batches = 0
    for images, targets in prog_bar:
        optimizer.zero_grad()

        images = [image.to(DEVICE) for image in images]
        targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]

        loss_dict = model(images, targets)
        # total loss = sum of all individual loss terms returned by the model
        losses = sum(loss for loss in loss_dict.values())
        loss_value = losses.item()

        losses.backward()
        optimizer.step()

        running_loss += loss_value
        num_batches += 1
        # update the loss value beside the progress bar for each iteration
        prog_bar.set_description(desc=f"Loss: {loss_value:.4f}")

    return running_loss / num_batches if num_batches else float('nan')

In [None]:
# Build the detector and place it on the configured device.
model = create_model(num_classes=num_classes).to(DEVICE)

# Plain SGD with momentum and weight decay over all model parameters.
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=0.005,
    momentum=0.9,
    weight_decay=0.0005,
)

# One call to train() per epoch; its return value is kept in `train_loss`.
num_epochs = 15
for epoch in range(num_epochs):
    train_loss = train(train_loader, model, optimizer)

In [None]:
# Save only the model's state_dict (not the whole model object) to 'model.pth'.
# NOTE(review): the path is relative to the current working directory, which is
# presumably not inside the mounted Drive -- confirm the file is kept.
torch.save(model.state_dict(), 'model.pth')

In [None]:
 # Pick at most eight .jpg files from `image_dir` (the same directory the
 # training dataset reads from) to run predictions on.
 # glob gives no ordering guarantee, so which eight files are chosen may vary.
 image_paths = glob.glob(f"{image_dir}/*.jpg")[:8]
 # last expression of the cell: display the selected paths
 image_paths

In [None]:
from google.colab.patches import cv2_imshow

In [None]:
# Run the trained detector on `image_paths` and show each image that has at
# least one raw detection, with the boxes scoring >= `detection_threshold`.
detection_threshold = 0.5
test_images = image_paths
# eval mode: the model returns predictions instead of losses
model = model.eval()
for i, test_image_path in enumerate(test_images):
    image = cv2.imread(test_image_path)
    # cv2.imread returns None for unreadable files; skip them instead of
    # failing on `.copy()` with an obscure AttributeError.
    if image is None:
        print(f"Image {i+1} could not be read, skipping: {test_image_path}")
        print('-'*50)
        continue
    # untouched BGR copy used for drawing / display
    orig_image = image.copy()
    # BGR to RGB
    image = cv2.cvtColor(orig_image, cv2.COLOR_BGR2RGB).astype(np.float32)
    # make the pixel range between 0 and 1
    image /= 255.0
    # bring color channels to front. (The former `.astype(np.float)` is gone:
    # the `np.float` alias no longer exists in current NumPy releases and the
    # tensor below is created as float32 anyway.)
    image = np.transpose(image, (2, 0, 1))
    # convert to tensor on the same device as the model (works on CPU-only
    # machines too, unlike a hard-coded `.cuda()`)
    image = torch.tensor(image, dtype=torch.float).to(DEVICE)
    # add batch dimension
    image = torch.unsqueeze(image, 0)
    with torch.no_grad():
        outputs = model(image)

    # load all detection to CPU for further operations
    outputs = [{k: v.to('cpu') for k, v in t.items()} for t in outputs]
    # carry further only if there are detected boxes
    if len(outputs[0]['boxes']) != 0:
        boxes = outputs[0]['boxes'].data.numpy()
        scores = outputs[0]['scores'].data.numpy()
        # filter out boxes according to `detection_threshold`
        boxes = boxes[scores >= detection_threshold].astype(np.int32)

        # draw the bounding boxes; every box gets the same fixed caption,
        # independent of the predicted label
        for box in boxes:
            cv2.rectangle(orig_image,
                        (int(box[0]), int(box[1])),
                        (int(box[2]), int(box[3])),
                        (0, 0, 255), 2)
            cv2.putText(orig_image, 'acne',
                        (int(box[0]), int(box[1]-5)),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0),
                        2, lineType=cv2.LINE_AA)
        cv2_imshow(orig_image)

    print(f"Image {i+1} done...")
    print('-'*50)
print('TEST PREDICTIONS COMPLETE')