In [1]:
import os
from pathlib import Path
import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, models
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from sklearn.metrics import (accuracy_score, precision_recall_fscore_support,
                             confusion_matrix, classification_report)
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

In [2]:
# Configuration: dataset locations, training hyperparameters, RNG seed, device.
DATA_ROOT = Path("data/tree")
TRAIN_DIR = DATA_ROOT / "train"
TEST_DIR = DATA_ROOT / "test"
VAL_DIR = DATA_ROOT / "val"

IMG_SIZE = 224          # side length (px) of the square images fed to the network
BATCH_SIZE = 64
NUM_WORKERS = 4         # DataLoader worker processes per loader
NUM_EPOCHS = 20
LEARNING_RATE = 0.001

# Seed the RNGs so data shuffling, random augmentation, and any later model
# initialisation repeat across Restart & Run All. torch.manual_seed also seeds
# every CUDA device; bit-exact GPU runs additionally need cuDNN determinism flags.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", DEVICE)

Device: cuda


In [4]:
# Build the image transforms, the train/val/test datasets, and their DataLoaders.
# Training images get light random augmentation; val/test images are only
# resized and normalised so evaluation is deterministic.

# Standard ImageNet per-channel mean/std (the statistics torchvision's
# pretrained weights expect). Shared by every pipeline below.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

train_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),  # angle sampled uniformly from [-10, 10] degrees
    transforms.ColorJitter(brightness=0.08, contrast=0.08, saturation=0.08, hue=0.02),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

val_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

train_ds = ImageFolder(TRAIN_DIR, transform=train_transform)
val_ds = ImageFolder(VAL_DIR, transform=val_transform)
test_ds = ImageFolder(TEST_DIR, transform=val_transform)

# ImageFolder derives label indices from each split's own sorted sub-folder
# names. If a class folder is missing from val/test, every later index shifts
# and the metrics are silently wrong -- fail loudly instead.
for split_name, split_ds in (("val", val_ds), ("test", test_ds)):
    assert split_ds.class_to_idx == train_ds.class_to_idx, (
        f"{split_name} classes differ from train: "
        f"{sorted(set(split_ds.classes) ^ set(train_ds.classes))}"
    )

# Pinned host memory speeds up host->GPU copies; it is useless on CPU.
loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=DEVICE.type == "cuda",
)
train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_ds, shuffle=False, **loader_kwargs)

class_names = train_ds.classes
num_classes = len(class_names)
print("Classes:", class_names)
print("Class Count:", num_classes)
print("Sizes of Train:", len(train_ds), "Val:", len(val_ds), "Test:", len(test_ds))

Classes: ['Acer palmatum', 'Cedrus deodara', 'Celtis sinensis', 'Cinnamomum camphora (Linn) Presl', 'Elaeocarpus decipiens', 'Flowering cherry', 'Ginkgo biloba', 'Koelreuteria paniculata', 'Lagerstroemia indica', 'Liquidambar formosana', 'Liriodendron chinense', 'Magnolia grandiflora L', 'Magnolia liliflora Desr', 'Michelia chapensis', 'Osmanthus fragrans', 'Photinia serratifolia', 'Platanus', 'Prunus cerasifera f. atropurpurea', 'Salix babylonica', 'Sapindus saponaria', 'Styphnolobium japonicum', 'Triadica sebifera', 'Zelkova serrata']
Class Count: 23
Sizes of Train: 3850 Val: 482 Test: 472
