In [None]:
import torch 
import matplotlib.pyplot as plt
import torchvision as vision
from torchvision import datasets, transforms
from matplotlib.patches import Circle
from PIL import Image
from torchvision.io import read_image
from torchvision.models import resnet50, ResNet50_Weights
import os
import shutil

In [None]:
# Build the unlabelled cat dataset from a local image folder.
TRAIN_DATA_PATH = "PetImages/CatUnlabelled/"

# Every image is resized (255) and then centre-cropped (224) before use.
transform = transforms.Compose(
    [
        transforms.Resize(255),
        transforms.CenterCrop(224),
    ]
)

train_dataset = datasets.ImageFolder(TRAIN_DATA_PATH, transform=transform)

In [None]:
#
# Feeds an image into the ResNet classifier
# Returns the highest class prediction and its score
#
def guessImage(img):
    """Classify a single image with the pretrained ResNet-50.

    Args:
        img: Image accepted by the weights' inference transforms
            (the callers in this notebook pass PIL images).

    Returns:
        A ``(category_name, score)`` tuple where ``category_name`` is the
        top-1 class name from the weights' metadata and ``score`` is its
        softmax probability expressed as a percentage (0-100).
    """
    # The model, weights and preprocessing pipeline are built once and reused;
    # rebuilding ResNet-50 for every crop dominated the runtime.
    weights, model, preprocess = _load_classifier()

    # Apply the inference preprocessing and add the batch dimension.
    batch = preprocess(img).unsqueeze(0)

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        prediction = model(batch).squeeze(0).softmax(0)

    class_id = prediction.argmax().item()
    score = prediction[class_id].item()
    category_name = weights.meta["categories"][class_id]
    return category_name, 100 * score


def _load_classifier():
    """Return the cached ``(weights, model, preprocess)`` triple, building it on first use."""
    cached = getattr(_load_classifier, "_cached", None)
    if cached is None:
        weights = ResNet50_Weights.DEFAULT
        model = resnet50(weights=weights)
        model.eval()
        cached = (weights, model, weights.transforms())
        _load_classifier._cached = cached
    return cached

#
# First iteration of cropping method - superseded due to class overlap problem
#
def createCropsOLD(img):
    """Return the image plus overlapping crops taken on a 4x4 grid.

    The result holds, in order: the original image, nine crops spanning
    2x2 grid cells, and four crops spanning 3x3 grid cells (14 entries).

    Fixes relative to the first version:
      * the stray ``index += 1`` statements referenced a local that was never
        assigned, so every call raised ``UnboundLocalError``;
      * the final crop was appended twice;
      * the grid is derived from ``img.size`` rather than ``getbbox()``, which
        reports the bounding box of non-zero pixels (or ``None``) and not the
        image extent.

    Args:
        img: PIL-style image exposing ``size`` and ``crop``.

    Returns:
        list: ``[img, <9 crops of 2x2 cells>, <4 crops of 3x3 cells>]``.
    """
    width, height = img.size
    gridX = width / 4
    gridY = height / 4

    zoomList = [img]

    # (span in grid cells, number of start positions per axis):
    # nine 2x2 crops followed by four 3x3 crops.
    for span, positions in ((2, 3), (3, 2)):
        for row in range(positions):
            for col in range(positions):
                xLeft = col * gridX
                xRight = (col + span) * gridX
                yTop = row * gridY
                yBottom = (row + span) * gridY
                zoomList.append(img.crop((xLeft, yTop, xRight, yBottom)))

    return zoomList

#
# Takes an image and creates nine non-overlapping cropped versions of it
#
def createCrops(img):
    """Return the image followed by its nine non-overlapping 3x3 grid tiles.

    The grid is computed from ``img.size``. The previous ``getbbox()`` call
    returned the bounding box of the non-zero pixels, which is ``None`` for an
    all-black image (crash on unpacking) and smaller than the image when it
    has black borders (tiles no longer cover the whole picture).

    Args:
        img: PIL-style image exposing ``size`` and ``crop``.

    Returns:
        list: ``[img, tile, ...]`` with ten entries. Tiles are emitted column
        by column (left to right), top to bottom within each column.
    """
    width, height = img.size
    box_width = width / 3
    box_height = height / 3

    zoomList = [img]
    for x in range(3):
        for y in range(3):
            left = x * box_width
            right = (x + 1) * box_width
            top = y * box_height
            bot = (y + 1) * box_height
            zoomList.append(img.crop((left, top, right, bot)))
    return zoomList


#
# Takes in a dataset and output folder path to label all images within with crops based on the guessImage and createCrop functions
# Populates the output folder with new labelled data
#
def determineOptimalZoom(dataset, output_folder):
    """Crop every image in ``dataset``, save all crops, and mark the best one.

    For each dataset item the image is expanded with ``createCrops`` and every
    resulting image is classified with ``guessImage``. All of them are saved
    to ``output_folder`` as ``<n>.jpg``. The one with the highest score among
    those whose predicted class is listed in ``cat_labels.txt`` is then renamed
    to ``<n>R.jpg``. If no crop is predicted as a listed class, nothing is
    renamed for that image.

    File numbering starts at 1 and matches the original scheme: one number is
    skipped after each source image's crops.

    Args:
        dataset: Indexable collection yielding ``(image, label)`` pairs.
        output_folder: Directory that receives the saved crops. It is created
            if missing; a trailing slash is optional.

    Raises:
        FileNotFoundError: If ``cat_labels.txt`` is not in the working directory.
    """
    # Accepted class names, one per line; a set gives O(1) membership checks.
    with open('cat_labels.txt') as f:
        classes = {line.strip() for line in f}

    if output_folder:
        os.makedirs(output_folder, exist_ok=True)

    index = 1
    for i in range(len(dataset)):
        img, _label = dataset[i]
        zoomList = createCrops(img)

        # Reserve this image's numbers; the extra +1 reproduces the original
        # gap between consecutive source images so file names stay stable.
        startingIndex = index
        index += 1 + len(zoomList)

        # Determining the optimal crop
        topScore = -1
        topIndex = -1
        for dispIndex, image in enumerate(zoomList, start=startingIndex):
            category_name, score = guessImage(image)
            if category_name in classes and score > topScore:
                topScore = score
                topIndex = dispIndex
            image.save(os.path.join(output_folder, str(dispIndex) + ".jpg"))

        print("Top Score ", topScore, " Index ", topIndex)
        if topIndex != -1:
            # os.replace overwrites an existing target, so re-running the
            # notebook does not fail on platforms where os.rename refuses to.
            os.replace(
                os.path.join(output_folder, str(topIndex) + ".jpg"),
                os.path.join(output_folder, str(topIndex) + "R.jpg"),
            )
        

OUTPUT_FOLDER_PATH = "PetImages/RevisitedCat/"
# Make sure the crop folder exists before anything is saved into it.
os.makedirs(OUTPUT_FOLDER_PATH, exist_ok=True)
determineOptimalZoom(train_dataset, output_folder=OUTPUT_FOLDER_PATH)


#Separates Optimal and Suboptimal images into separate class folders
src_path = "PetImages/RevisitedCat"
dest_correct = "PetImages/RevisitedTraining/Correct"
dest_incorrect = "PetImages/RevisitedTraining/Incorrect"

# shutil.copy treats a missing destination directory as a *file name*, so
# without these every image would be copied over one file called
# "Correct" / "Incorrect".
os.makedirs(dest_correct, exist_ok=True)
os.makedirs(dest_incorrect, exist_ok=True)

for f in os.listdir(src_path):
    src_file = os.path.join(src_path, f)
    # Skip sub-directories (e.g. .ipynb_checkpoints); shutil.copy only handles files.
    if not os.path.isfile(src_file):
        continue
    # Files renamed to "<n>R.jpg" are the optimal crops.
    if str(f).endswith("R.jpg"):
        shutil.copy(src_file, dest_correct)
    else:
        shutil.copy(src_file, dest_incorrect)


In [None]:
# Using a test set to create cropped images
TEST_DATA_PATH = "PetImages/TestCat/"
TEST_CROPPED_PATH = "PetImages/TestCropped/"
test_dataset = datasets.ImageFolder(TEST_DATA_PATH, transform=transform)
print(len(test_dataset))

# image.save does not create missing directories.
os.makedirs(TEST_CROPPED_PATH, exist_ok=True)

for i in range(len(test_dataset)):
    img, lab = test_dataset[i]
    zoomList = createCrops(img)
    # Files are named "<image index>_<crop index>.jpg"; crop 0 is the uncropped image.
    for j, image in enumerate(zoomList):
        name = os.path.join(TEST_CROPPED_PATH, str(i) + "_" + str(j) + ".jpg")
        image.save(name)
