In [None]:
import torch
from torch.optim import Adam
import torch.nn as nn 
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision import transforms
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from glob import glob
from torch_snippets import *

# Prefer the GPU whenever torch can see one; otherwise stay on the CPU.
device = 'cpu' if not torch.cuda.is_available() else 'cuda'

class ImageDataset(Dataset):
  """Object-detection dataset pairing JPEG images with box annotations from a DataFrame.

  `df` must provide the columns `ImageID`, `LabelName`, `XMin`, `YMin`, `XMax`
  and `YMax`. The box columns are multiplied by the image width/height, so they
  are presumably fractions of the image size -- TODO confirm against the
  annotation file.

  Attributes:
    label2name: maps each distinct `LabelName` value to an integer class id
      (ids start at 1; the key 'background' maps to 0).
    name2label: the inverse mapping, integer class id -> label string.
    num_of_classes: number of class ids, background included.
    image_id: the distinct `ImageID` values; one dataset item per id.
  """
  # Nominal (width, height) the transforms are expected to resize to. Boxes are
  # scaled by the actual tensor size in __getitem__, so this attribute is kept
  # only for code that reads it.
  w_and_h = (224,224)

  def __init__(self, df, root_directory, transformer = None):
    """Index the annotations and list the JPEG files under `root_directory`.

    Args:
      df: annotation DataFrame (see the class docstring for required columns).
      root_directory: folder that is globbed for `*.jpg` files.
      transformer: optional callable applied to the PIL image. It should return
        a C x H x W tensor (e.g. a pipeline ending in `ToTensor`).
    """
    self.df = df
    self.files = glob(root_directory + "/*.jpg")
    self.transformer = transformer
    # A DataFrame column is selected with [], not by calling the frame.
    self.label2name = {l:t+1 for t, l in enumerate(self.df['LabelName'].unique())}
    self.label2name['background'] = 0
    self.name2label = {t:l for l, t in self.label2name.items()}
    self.num_of_classes = len(self.label2name)
    self.image_id = self.df.ImageID.unique()

  def collate_fn(self, batch):
    """Regroup a list of (image, target) pairs into (images, targets) tuples."""
    return tuple(zip(*batch))

  def __len__(self):
    """Return the number of distinct image ids."""
    return len(self.image_id)

  @staticmethod
  def _to_chw_tensor(img):
    """Convert an H x W x C uint8 image to a C x H x W float tensor scaled to [0, 1]."""
    return torch.tensor(np.array(img)).permute(2,0,1).float() / 255.

  def __getitem__(self, index):
    """Return `(image, target)` for the `index`-th image id.

    `image` is a C x H x W tensor. `target` is a dict with `boxes` (float
    tensor, one `[xmin, ymin, xmax, ymax]` row per annotation, in pixels of the
    returned image) and `labels` (long tensor of class ids). Everything is
    moved to the module-level `device`.
    """
    # Load the image that belongs to this id.
    # `find` comes from torch_snippets; presumably it returns the entry of
    # self.files that matches image_id -- TODO confirm.
    image_id = self.image_id[index]
    image_path = find(image_id, self.files)
    img = Image.open(image_path).convert("RGB")
    if self.transformer:
      # A pipeline ending in ToTensor already yields C x H x W, so no permute.
      img = self.transformer(img)
    if not isinstance(img, torch.Tensor):
      # No transformer (or one that returned an image): convert by hand.
      img = self._to_chw_tensor(img)
    height, width = img.shape[-2:]

    # Collect the labels and box coordinates annotated for this image.
    rows = self.df[self.df['ImageID'] == image_id]
    labels = rows['LabelName'].values.tolist()
    # astype() copies, so the in-place scaling below never touches self.df.
    coords = rows[['XMin', 'YMin','XMax','YMax']].values.astype(np.float64)

    # Scale the coordinates to pixels of the tensor actually returned.
    coords[:, [0,2]] *= width
    coords[:, [1,3]] *= height

    # Truncate to whole pixels.
    boxes = coords.astype(np.uint32).tolist()

    target = {}
    # reshape keeps the (N, 4) layout even if an image has no rows.
    target['boxes'] = torch.tensor(boxes).float().reshape(-1, 4).to(device)
    target['labels'] = torch.tensor([self.label2name[i] for i in labels]).long().to(device)

    return img.to(device), target

In [None]:
def get_model(num_of_classes):
  """Build a pretrained Faster R-CNN whose box head predicts `num_of_classes` classes.

  Args:
    num_of_classes: number of output classes for the replacement predictor.
      For torchvision detection models this count includes the background class.

  Returns:
    The torchvision Faster R-CNN (ResNet-50 FPN backbone) with a new
    `FastRCNNPredictor` as its box predictor.
  """
  # The constructor is spelled `resnet`, not `resenet`.
  model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained = True)
  # Reuse the existing head's input width so the new predictor fits the backbone.
  in_features = model.roi_heads.box_predictor.cls_score.in_features
  model.roi_heads.box_predictor = FastRCNNPredictor(in_features,num_of_classes)
  return model

def train(inputs, model, optimizer):
  """Run one optimisation step on a single batch.

  Args:
    inputs: pair `(images, targets)` for one batch.
    model: module that, in training mode, returns a dict of named loss tensors
      when called as `model(images, targets)`.
    optimizer: optimizer that holds the model's parameters.

  Returns:
    `(loss, losses)`: the summed loss tensor and the per-component loss dict.
  """
  model.train()
  # Named `images` rather than `input` to avoid shadowing the builtin.
  images, targets = inputs
  optimizer.zero_grad()
  losses = model(images, targets)
  # dict exposes .values(); the total loss is the sum of every component.
  loss = sum(component for component in losses.values())
  loss.backward()
  optimizer.step()
  return loss, losses

def train_epoch(n_epochs, model, optimizer):
  """Placeholder for the multi-epoch training loop.

  Not implemented yet: a training loop needs a data loader to iterate over,
  and this signature does not receive one.

  Raises:
    NotImplementedError: always.
  """
  raise NotImplementedError("train_epoch has not been implemented yet")

In [None]:
# The module name of a directly executed notebook/script is "__main__".
if __name__ == "__main__":
  # TODO: fill in these placeholders before running. `df_raw` must be the
  # annotation DataFrame expected by ImageDataset, not a string.
  training_directory = ""
  validation_directory = ""
  df_raw = ""
  batch_size = 16
  n_epochs = 100

  # Only resize and convert to a tensor here:
  #  * random flips would move the objects without moving their boxes, because
  #    the transform sees the image only;
  #  * torchvision detection models expect pixel values in the 0-1 range and
  #    normalise internally, so no extra Normalize step is applied.
  tr_transformer = transforms.Compose([
      transforms.Resize((224,224)),
      transforms.ToTensor()
  ])

  vl_transformer = transforms.Compose([
      transforms.Resize((224,224)),
      transforms.ToTensor()
  ])

  print("Loading Dataset: Training")
  tr_set = ImageDataset(df_raw, training_directory, tr_transformer)
  print("Loading Data to Loader: Training")
  # Pass the bound method itself; DataLoader calls it once per batch.
  tr_dl = DataLoader(tr_set,batch_size = batch_size, shuffle = True, drop_last = True, collate_fn=tr_set.collate_fn)


  # The predictor head must be sized to the dataset's class count.
  model = get_model(tr_set.num_of_classes).to(device)
  # Adam takes no `momentum` argument; its default betas play that role.
  optimizer = Adam(model.parameters(),lr = 1e-3,weight_decay=0.0005)
  log = Report(n_epochs)


