## Connecting to Drive

In [2]:
from google.colab import drive
drive.mount('/gdrive')

Mounted at /gdrive


## Imports

In [18]:
import re
import os
import sys
import cv2
import time
import json
import string
import random
import warnings
import argparse
import progressbar
import numpy as np
import pandas as pd

from tqdm.notebook import tqdm

from fastai.vision import *
from fastai.basic_data import *
from fastai.metrics import accuracy
from fastai.callbacks.hooks import num_features_model

import torch

## Copying Data/Model

In [None]:
!mkdir data
!mkdir models

!cp '/gdrive/My Drive/Semester 8/CV/snek/datasets/train_small_10.tar.gz' ./data
!cp '/gdrive/My Drive/Semester 8/CV/snek/models/snake-detection-model.pt' ./models

!tar -xvf ./data/train_small_10.tar.gz -C ./data/

In [7]:
MODEL_PATH = "./models/snake-detection-model.pt"
DATASET_PATH = "./data/train/"
OUTPUT_PATH = "./data/output/"

## Define Model and Datasets

In [10]:
class SnakeDetector(nn.Module):
    def __init__(self, arch=models.resnet18):
        super().__init__() 
        self.cnn = create_body(arch)
        self.head = create_head(num_features_model(self.cnn) * 2, 4)
        
    def forward(self, im):
        x = self.cnn(im)
        x = self.head(x)
        return x.sigmoid_()

In [11]:
class CustomDataset(Dataset):
    def __init__(self, j, aug=None):
        self.j = j
        if aug is not None: aug = get_aug(aug)
        self.aug = aug
    
    def __getitem__(self, idx):
        item = j2anno(self.j[idx])
        if self.aug: item = self.aug(**item)
        im, bbox = item['image'], np.array(item['bboxes'][0])
        im, bbox = self.normalize_im(im), self.normalize_bbox(bbox)
        
        return im.transpose(2,0,1).astype(np.float32), bbox.astype(np.float32)
    
    def __len__(self):
        return len(self.j)
    
    def normalize_im(self, ary):
        return ((ary / 255 - imagenet_stats[0]) / imagenet_stats[1])
    
    def normalize_bbox(self, bbox):
        return bbox / SZ


## Initialize Model

In [12]:
src = (ImageList.from_folder(path=DATASET_PATH).split_by_rand_pct(0.0).label_from_folder())
tfms = get_transforms(do_flip=True,flip_vert=False,max_rotate=10.0,max_zoom=1.1,max_lighting=0.2,max_warp=0.2,p_affine=0.75,p_lighting=0.75)
data = (src.transform(tfms, size=360, resize_method=ResizeMethod.SQUISH)
        .databunch(bs=32).normalize(imagenet_stats))

  "The default behavior for interpolate/upsample with float scale_factor changed "
  "The default behavior for interpolate/upsample with float scale_factor changed "
  "The default behavior for interpolate/upsample with float scale_factor changed "
  "The default behavior for interpolate/upsample with float scale_factor changed "


In [14]:
learn = Learner(data, SnakeDetector(arch=models.resnet50), loss_func=torch.nn.L1Loss())
learn.split([learn.model.cnn[:6], learn.model.cnn[6:], learn.model.head])

state_dict = torch.load(MODEL_PATH)
learn.model.load_state_dict(state_dict['model'])

<All keys matched successfully>

## Crop Images

In [19]:
src_new = (ImageList.from_folder(path=DATASET_PATH).split_by_rand_pct(0.0).label_from_folder())
str_name = str(src_new.items[0])

with warnings.catch_warnings():
    warnings.simplefilter("ignore")

    for filename in tqdm(src_new.items):
        try:
            start = time.time()
            
            im = cv2.imread(f"{filename}", cv2.IMREAD_COLOR)
            im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
            im = cv2.resize(im, (360,360), interpolation = cv2.INTER_AREA)
            im_height, im_width, _ = im.shape
            
            orig_im = cv2.imread(f"{filename}", cv2.IMREAD_COLOR)
            orig_im_height, orig_im_width, _ = orig_im.shape
            to_pred = open_image(filename)
            
            _,_,bbox=learn.predict(to_pred)
            
            im_original = cv2.imread(f"{filename}", cv2.IMREAD_COLOR)
            im_original = cv2.cvtColor(im_original, cv2.COLOR_BGR2RGB)
            im_original.shape
            im_original_width = im_original.shape[1]
            im_original_height = im_original.shape[0]
            
            bbox_new = bbox
            bbox_new[0] = bbox_new[0]*im_original_width 
            bbox_new[2]= bbox_new[2]*im_original_width
            bbox_new[1] = bbox_new[1]*im_original_height
            bbox_new[3] = bbox_new[3]*im_original_height
            x_min, y_min, x_max, y_max = map(int, bbox_new)
            
            im_original = im_original[y_min:y_max,x_min:x_max]
            im_original = cv2.cvtColor(im_original,cv2.COLOR_BGR2RGB)
            filename_str = str(filename)
            
            to_save = filename_str.replace('train','cropped_images')
            to_save = to_save.split("/")
            file_name = "/".join(to_save[len(to_save)-2:])
            class_id = OUTPUT_PATH + to_save[-2]
            
            if not os.path.exists(class_id):
                os.makedirs(class_id)
            
            to_save = OUTPUT_PATH + file_name
            cv2.imwrite(to_save, im_original)
            # print("saved", to_save)
            # print('It took', time.time()-start, 'seconds.')
        except Exception as e:
            print(str(e))


HBox(children=(FloatProgress(value=0.0, max=15014.0), HTML(value='')))

OpenCV(4.1.2) /io/opencv/modules/imgproc/src/color.cpp:182: error: (-215:Assertion failed) !_src.empty() in function 'cvtColor'

OpenCV(4.1.2) /io/opencv/modules/imgproc/src/color.cpp:182: error: (-215:Assertion failed) !_src.empty() in function 'cvtColor'

OpenCV(4.1.2) /io/opencv/modules/imgproc/src/color.cpp:182: error: (-215:Assertion failed) !_src.empty() in function 'cvtColor'

OpenCV(4.1.2) /io/opencv/modules/imgproc/src/color.cpp:182: error: (-215:Assertion failed) !_src.empty() in function 'cvtColor'

OpenCV(4.1.2) /io/opencv/modules/imgproc/src/color.cpp:182: error: (-215:Assertion failed) !_src.empty() in function 'cvtColor'

OpenCV(4.1.2) /io/opencv/modules/imgproc/src/color.cpp:182: error: (-215:Assertion failed) !_src.empty() in function 'cvtColor'

OpenCV(4.1.2) /io/opencv/modules/imgproc/src/color.cpp:182: error: (-215:Assertion failed) !_src.empty() in function 'cvtColor'

OpenCV(4.1.2) /io/opencv/modules/imgproc/src/color.cpp:182: error: (-215:Assertion failed) !_src.

## Save Data

In [33]:
!for name in ./data/outputclass*; do newname="$(echo "$name" | cut -c14-)"; mv "$name" "$newname"; done

In [34]:
!mkdir ./data/train_cropped
!mv class* ./data/train_cropped

In [None]:
!tar -cvzf train_small_10_cropped.tar.gz ./data/train_cropped

In [38]:
!cp train_small_10_cropped.tar.gz '/gdrive/My Drive/Semester 8/CV/snek/datasets/'

In [40]:
!du -sh train_small_10_cropped.tar.gz

!find ./data/train/ -type f | wc -l
!find ./data/train_cropped/ -type f | wc -l

1.5G	train_small_10_cropped.tar.gz
15014
14981


33 images not cropped because error