In [173]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split, Subset
import torchvision.models as models
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import numpy as np
from collections import Counter
from sklearn.metrics import classification_report
from PIL import Image
import os
import shutil

In [174]:
# Tree condition category to process. The value is interpolated into the
# input and output folder paths built further down in this notebook.
# NOTE(review): "Zeer-Slecht" looks like Dutch for "very poor" - confirm.
condition = "Zeer-Slecht"

In [175]:
# Preprocessing pipeline applied to each image before it is fed to the model.
transform = transforms.Compose(
    [
        # Bring every image to a fixed 224x224 size.
        transforms.Resize((224, 224)),
        # PIL image -> PyTorch tensor.
        transforms.ToTensor(),
        # Per-channel normalisation (these values match the commonly used
        # ImageNet mean / std).
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [176]:
# Folder holding the unfiltered input images for the selected condition.
data_path = f"data/trees/{condition}/unfiltered/images/"

In [177]:
# Load the fine-tuned weights (a state dict) from disk.
# map_location="cpu" makes this work on machines without a GPU even when the
# checkpoint was saved from a CUDA device; inference below runs on the CPU.
# NOTE(security): torch.load unpickles the file, so only load checkpoints from
# a trusted source (newer PyTorch versions also offer weights_only=True).
resnet_dict = torch.load("resnet18_model.pt", map_location="cpu")

In [178]:
# Rebuild the network architecture and restore the fine-tuned weights.
# The backbone is created WITHOUT pretrained ImageNet weights: every parameter
# and buffer is overwritten by load_state_dict() right below, so downloading
# them would only add a network dependency and start-up time.
resnet18 = models.resnet18()
num_classes = 2
# Replace the final fully connected layer so it outputs `num_classes` logits.
resnet18.fc = nn.Linear(resnet18.fc.in_features, num_classes)
# Strict loading (the default) raises if the checkpoint does not match this architecture.
resnet18.load_state_dict(resnet_dict)

# Inference mode: fixes batch-norm statistics and disables dropout.
# The trailing semicolon suppresses the long model repr in the cell output.
resnet18.eval();

# The image preprocessing pipeline (`transform`) is defined once in an earlier
# cell; the identical copy that used to live here has been removed.


In [179]:
# Input image folder for the selected condition (same value as assigned earlier in the notebook).
data_path = "data/trees/{}/unfiltered/images/".format(condition)


In [180]:
# Run the classifier over every .jpg / .png file found in `data_path`.
#   predicted_labels - predicted class index of each successfully processed image
#   filenames_trees  - file names of the images whose predicted class index is 1
predicted_labels = []
filenames_trees = []
for filename in tqdm(os.listdir(data_path)):
    # Only image files are considered (the extension check is case-sensitive).
    if not filename.endswith((".jpg", ".png")):
        continue

    image_path = os.path.join(data_path, filename)
    if not os.path.exists(image_path):
        # The entry was listed but cannot be resolved (e.g. a broken symlink).
        print("Skipped file name: ", filename)
        continue

    try:
        # The context manager closes the underlying file handle once the
        # pixel data has been turned into a tensor.
        with Image.open(image_path) as pil_image:
            if pil_image.mode != 'RGB':
                pil_image = pil_image.convert('RGB')
            # Apply the preprocessing and add the batch dimension.
            image = transform(pil_image).unsqueeze(0)

        with torch.no_grad():
            outputs = resnet18(image)
        # Index of the highest logit along the class dimension.
        _, predicted = torch.max(outputs, 1)
    except Exception as exc:
        # Best effort: report the failing file together with the reason and
        # carry on. Unlike a bare `except:`, this does not swallow
        # KeyboardInterrupt / SystemExit, so the loop can still be interrupted.
        print("Could not open file name: ", filename, "-", repr(exc))
        continue

    label = predicted.item()
    predicted_labels.append(label)
    if label == 1:
        filenames_trees.append(filename)

100%|██████████| 131/131 [00:04<00:00, 28.48it/s]


In [181]:
# Number of images that received a prediction (one list entry per classified image).
len(predicted_labels)

131

In [182]:
# Number of images whose predicted class index was 1.
len(filenames_trees)

87

In [183]:
# Peek at the first collected file name as a quick sanity check.
filenames_trees[0]

'tree_123.jpg'

In [184]:
# Output folder that receives the selected images for this condition.
destination_folder = f"./data/trees_classified/{condition}/"

In [185]:
# Copy every image that was predicted as class 1 into `destination_folder`.
# The folder (including missing parents) is created first; without it
# shutil.copy fails on a fresh checkout where the directory does not exist yet.
os.makedirs(destination_folder, exist_ok=True)
# The loop variable is named `tree_filename` so it does not reuse the name
# `image`, which holds an image tensor in the inference cell.
for tree_filename in tqdm(filenames_trees):
    shutil.copy(os.path.join(data_path, tree_filename), destination_folder)

100%|██████████| 87/87 [00:00<00:00, 5614.17it/s]
