In [1]:
from PIL import Image, ImageFont, ImageDraw 

In [None]:
# pipenv install opencv-python
# sudo apt-get update
# sudo apt-get install libgl1-mesa-glx

In [2]:
import torch
from torch.utils.data import Dataset
from glob import glob
import os
import pandas as pd
from torchvision import io
from enum import Enum

import torchvision.transforms as transforms
class Devices(Enum):
    iphone_11 = 'iphone 11'
    samsung_s21 = 'samsung S21'
    all = '*'

class Environments(Enum):
    indoor = 'Indoor'
    outdoor = 'Outdoor'
    all = '*'

class Elements(Enum):
    om = 'OM'
    p = 'P'
    k = 'K'

class Preprocessing(Enum):
    training = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(350),
        transforms.CenterCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    inferencing  = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(350),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])


class SoilDataset_bigset(Dataset):
    def __init__(self, dataset_name:str, element:Elements, device:Devices, 
                 environment:Environments, 
                 preprocessing:Preprocessing, 
                 clip_target:bool=False, 
                 normalize_target:bool=False):
        self.element:Elements = element
        self.clip_target:bool = clip_target
        self.normalize_target:bool = normalize_target
        
        # Set max value for clipping and normalizing
        self.max_value:float = 0
        if(self.element == Elements.om):
            self.max_value = 8.0
        elif(self.element == Elements.p):
            self.max_value = 1000.0
        elif(self.element == Elements.k):
            self.max_value = 1500.0

        # BasePath of the dataset
        base_path:str = "../dataset"
        dataset_path:str = os.path.join(base_path, dataset_name)
        if(os.path.exists(dataset_path) == False):
            raise ValueError(f"Path={dataset_path} is not exist. Execution path={os.getcwd()}")

        # Inside this path there must be a list of folders arange by mobile phone. Use device enum.
        # Inside those mobile phone are 2 folders indicate the environment the image was taken in. Use environment enum.
        image_folder = os.path.join(dataset_path, element.value, device.value, environment.value)
        self.imgs = glob(os.path.join(image_folder,'*/*'))
        print(f"Found {len(self.imgs)} images in {image_folder}.")

        # Load csv file for lookup the target value
        target_path:str = os.path.join(dataset_path,element.value,'meta.csv')
        self.target_df = pd.read_csv(target_path, index_col='id')

        self.signature = os.path.join(element.value,device.value,environment.value)
        self.preprocessing = preprocessing.value

    def get_target(self, img_path:str) -> float:
        assert len(img_path.split('/')) == 8, f"Expect img_path to have 8 folders but got {img_path=}"
        target_id = int(img_path.split('/')[6])
        target_value = float(self.target_df.loc[target_id].iloc[0]) # type:ignore
        if(self.clip_target and (target_value > self.max_value)):
            target_value = self.max_value
        if(self.normalize_target):
            target_value = target_value / (self.max_value)
        return float(target_value)
        
    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        img_path = self.imgs[idx]
        y = self.get_target(img_path=img_path)
        y = torch.tensor(y)
        X = io.read_image(img_path)
        if self.preprocessing:
            X = self.preprocessing(X)
        return X.float(), y.float(), img_path

In [12]:
dataset = SoilDataset_bigset(dataset_name="workshop", 
                             element=Elements.om, 
                             device=Devices.all, 
                             environment=Environments.outdoor, 
                             preprocessing=Preprocessing.inferencing, 
                             clip_target=True, normalize_target=True)

Found 200 images in ../dataset/workshop/OM/*/Outdoor.


In [13]:
from torchvision import models
from torchvision.models import AlexNet_Weights

model = models.alexnet(weights=AlexNet_Weights.DEFAULT, progress=True)
model.classifier[6] = torch.nn.Linear(in_features=4096, out_features=1, bias=True)
model.eval

<bound method Module.eval of AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, ou

In [31]:
def show_image(img:Image.Image, label:str, predict:str, picname:str):
    draw = ImageDraw.Draw(img)
    font = ImageFont.load_default()
    # font.
    import cv2
    font = ImageFont.load_default(size=100)
    draw.text((10, 10),f"Label:{label}",(0,0,0), font=font)
    draw.text((10, 40 + font_size),f"Predict:{predict}",(0,0,0), font=font)
    img.show(picname)

In [32]:
for x,y,pic_name in dataset:
    # print(pic_name)
    yhat = model(x.reshape((1,3,224,224))).detach()[0][0]
    img = Image.open(f'{pic_name}')
    pic_name = os.path.join(pic_name)
    show_image(img, y.numpy(), yhat.numpy(), pic_name)
    break