In [None]:
"""  OPTIONAL: Download the strawberries dataset from Roboflow  """
!pip install roboflow

from roboflow import Roboflow
rf = Roboflow(api_key="6kfjjO565pfpxqd6YL4S")
project = rf.workspace("skripsie").project("strawberry.00")
dataset = project.version(15).download("yolov5")

In [None]:
"""  Object detection on Strawberries Dataset.
Libraries  """

import os
import random
import torch
from torchsummary import summary
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from PIL import Image, ImageDraw
from sklearn.model_selection import train_test_split
import matplotlib.patches as patches
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader

In [None]:
"""  Import Dataset  """

# Base directory for the dataset paths built in this notebook: the working
# directory of the kernel at the moment this cell is executed.
HOME = os.getcwd()

class MyDataSet(Dataset):
    """Unlabelled image dataset serving every regular file found in a directory.

    Args:
        main_dir: Directory that contains the image files.
        transform: Callable applied to each RGB PIL image; its result is what
            ``__getitem__`` returns.

    Attributes:
        total_imgs: Sorted list of the file names (not full paths) in ``main_dir``.
    """

    def __init__(self, main_dir, transform):
        self.main_dir = main_dir
        self.transform = transform
        # os.listdir returns entries in arbitrary order and also lists
        # sub-directories (e.g. a Jupyter ``.ipynb_checkpoints`` folder), which
        # would make Image.open fail. Keep regular files only and sort them so
        # the index -> file mapping is deterministic across runs.
        self.total_imgs = sorted(
            name
            for name in os.listdir(main_dir)
            if os.path.isfile(os.path.join(main_dir, name))
        )

    def __len__(self):
        """Return the number of image files in the dataset."""
        return len(self.total_imgs)

    def __getitem__(self, idx):
        """Load image number ``idx``, convert it to RGB and apply ``transform``."""
        img_loc = os.path.join(self.main_dir, self.total_imgs[idx])
        # convert() produces a new, fully loaded image, so the underlying file
        # can be closed as soon as the context manager exits.
        with Image.open(img_loc) as raw_image:
            image = raw_image.convert("RGB")
        tensor_image = self.transform(image)
        return tensor_image

# Preprocessing applied to every training image: convert the PIL image to a
# tensor, then resize it to 560x560.
# NOTE(review): 560 is a multiple of 14, presumably chosen to match the patch
# size of the "vits14" backbone used later in the notebook - confirm.
train_transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Resize((560, 560))]
)

train_dataset = MyDataSet(
    f"{HOME}/strawberry.00-15/train/images", transform=train_transform
)

# Batches of 4 images, reshuffled on every pass over the dataset.
train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True)

In [None]:
"""  Visualize Dataset  """


def plot_bounding_boxes(annotation_files):
    """Show each annotated image side by side with its bounding boxes drawn on top.

    Args:
        annotation_files: Paths of label files. Each one is loaded through
            ``get_image_and_annotations``; every annotation row is read as
            ``(class, x_center, y_center, width, height)`` with the four
            geometry values expressed as fractions of the image size.

    Returns:
        None. The figure is displayed with ``plt.show()``. Nothing is drawn
        when ``annotation_files`` is empty.
    """
    n_files = len(annotation_files)
    if n_files == 0:
        # plt.subplots cannot build a grid with zero columns.
        return

    # squeeze=False always yields a 2-D grid of axes, so indexing also works
    # when a single file is passed (otherwise a bare Axes would be returned).
    _, axes = plt.subplots(1, n_files, figsize=(20, 20), squeeze=False)
    axes_row = axes[0]

    for i, anno in enumerate(annotation_files):
        image, annotation_list = get_image_and_annotations(anno)
        w, h = image.size
        # Display the image
        axes_row[i].imshow(image)

        for ann in annotation_list:
            _obj_cls, cx, cy, box_w, box_h = ann
            # Rectangle expects the top-left corner in pixels, so move from the
            # box centre by half its size and scale by the image dimensions.
            top_left = ((cx - box_w / 2) * w, (cy - box_h / 2) * h)
            rect = patches.Rectangle(
                top_left, box_w * w, box_h * h,
                linewidth=1, edgecolor='r', facecolor='none',
            )
            # Add the patch to the Axes
            axes_row[i].add_patch(rect)

    plt.show()


def get_image_and_annotations(annotation_file):
    """Load a label file and the image it annotates.

    Args:
        annotation_file: Path to a ``.txt`` label file stored under a
            ``labels`` directory. Each non-blank line holds
            whitespace-separated numbers.

    Returns:
        A tuple ``(image, annotation_list)``: the opened PIL image and a list
        with one list of floats per non-blank line of the label file. The list
        is empty when the file contains no annotations.

    Raises:
        AssertionError: If the matching image file does not exist.
        ValueError: If a field in the label file is not a number.
    """
    annotation_list = []
    with open(annotation_file, "r") as file:
        for line in file:
            fields = line.split()
            # A trailing newline or an empty label file yields blank lines;
            # skip them instead of failing on float("").
            if fields:
                annotation_list.append([float(value) for value in fields])

    # Get the corresponding image file: same file stem with a .jpg extension,
    # in the sibling "images" directory. Only the last "labels" in the
    # directory part and the real extension are rewritten, so other occurrences
    # of "labels" or "txt" elsewhere in the path are left alone.
    label_dir, label_name = os.path.split(annotation_file)
    head, found, tail = label_dir.rpartition("labels")
    image_dir = head + "images" + tail if found else label_dir
    image_file = os.path.join(image_dir, os.path.splitext(label_name)[0] + ".jpg")
    assert os.path.exists(image_file), f"image not found: {image_file}"
    # Load the image
    image = Image.open(image_file)
    return image, annotation_list


In [None]:
""" Run Visualization Dataset """


folder_path = f"{HOME}/strawberry.00-15/train/labels"

# Only consider label files; sorting makes the sampling pool independent of
# the arbitrary order returned by os.listdir.
label_files = sorted(
    name for name in os.listdir(folder_path) if name.endswith(".txt")
)

# random.sample draws without replacement, so the same image is never shown
# twice. k is capped so folders with fewer than 5 label files still work.
anns = [
    os.path.join(folder_path, name)
    for name in random.sample(label_files, k=min(5, len(label_files)))
]

plot_bounding_boxes(anns)

In [None]:
""" Visualize Train Loader"""

import matplotlib.pyplot as plt
import numpy as np

def show(img):
    """Display an image tensor with matplotlib.

    The tensor is converted to a NumPy array and its axes are reordered from
    (0, 1, 2) to (1, 2, 0) before it is passed to ``plt.imshow``.
    """
    channels_last = np.transpose(img.numpy(), axes=(1, 2, 0))
    plt.imshow(channels_last)
    plt.show()

# Fetch one batch from the (shuffled) training loader and display its first image.
train_features = next(iter(train_dataloader))
# Index 0 selects the first item of the batch; squeeze() drops any size-1 dimensions.
img = train_features[0].squeeze()
show(img)

In [None]:
"""  DINOv2  """

# Smallest DINOv2 backbone, fetched through torch.hub.
dinov2_vits14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
# torchsummary prints the table itself, so wrapping the call in print() only
# adds a stray "None". Its `device` argument defaults to "cuda"; the model has
# not been moved to the GPU at this point (that happens in the inference
# cell), so the dummy input is explicitly created on the CPU to match it.
summary(dinov2_vits14, (3, 560, 560), device="cpu")

# Larger DINOv2 backbone ('dinov2_vitb14')
#dinov2_vitb14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14')
#summary(dinov2_vitb14, (3, 560, 560), device="cpu")

In [None]:
"""DINOv2 inference: Generate feature embeddings from dataset"""

from tqdm import tqdm

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'device: {device}')
dinov2_vits14 = dinov2_vits14.to(device)
# Inference only: switch off any train-time behaviour of the backbone.
dinov2_vits14.eval()

all_embeddings, all_targets = [], []

# NOTE(review): train_dataloader is created with shuffle=True, so the rows of
# all_embeddings do not follow the file order of the dataset. Use an
# unshuffled loader if embeddings must be matched back to images.
with torch.no_grad():
    for images in tqdm(train_dataloader):
        images = images.to(device)
        embedding = dinov2_vits14(images)
        # Move each batch to the CPU right away: this keeps GPU memory from
        # growing with the dataset and lets the result be converted with
        # .numpy() later, which is not possible for CUDA tensors.
        all_embeddings.append(embedding.cpu())

# Stack the per-batch outputs along the batch dimension.
all_embeddings = torch.cat(all_embeddings, dim=0)


In [None]:
"""
Understand DINOv2 output:
  each image (391 images in total) is transformed
  to a 384 dimensional feature space
"""

# Tensor.shape works on any device, whereas .numpy() raises for CUDA tensors.
# tuple() keeps the printed format identical to np.shape's output.
print(tuple(all_embeddings.shape))

# Count the regular files in the training image folder so the number can be
# compared with the first dimension printed above.
dir_path = f"{HOME}/strawberry.00-15/train/images"
with os.scandir(dir_path) as entries:  # closes the directory iterator promptly
    count = sum(1 for entry in entries if entry.is_file())
print('files in folder:', count)