### Paper parameters 
- 60% train, 40% split into val and test
- Images randomly cropped to 300x300 pixels
- 4 classes for classification
- cross entropy loss
- 15 epochs
- 0.001 initial learning rate, then a cosine annealing scheduler
- Sliding window of 300x300 with a stride of 50 pixels
    - Softmax per window. A window casts one vote for its highest-probability class if that probability exceeds the threshold.
    - Final prediction is the condition with the highest number of votes across all windows.
    - Probability threshold of 0.7. If no window exceeds it, the image is dropped.
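A minimal NumPy sketch of the voting scheme above, assuming `window_probs` holds one softmax row per window. The function name `vote_over_windows` is hypothetical, and dropping the image on a tied vote is an assumption: these notes do not say how the paper breaks ties.

```python
import numpy as np

def vote_over_windows(window_probs, threshold=0.7):
    """Majority vote over sliding-window softmax outputs (hypothetical helper).

    window_probs: array of shape (n_windows, n_classes).
    Returns the winning class index, or None when no window is confident
    enough or the vote is tied (the image is then dropped; tie handling is an assumption).
    """
    window_probs = np.asarray(window_probs)
    top_class = window_probs.argmax(axis=1)
    top_prob = window_probs.max(axis=1)
    votes = top_class[top_prob > threshold]  # only confident windows vote
    if votes.size == 0:
        return None
    counts = np.bincount(votes, minlength=window_probs.shape[1])
    winners = np.flatnonzero(counts == counts.max())
    if len(winners) > 1:
        return None  # tie: no unique majority
    return int(winners[0])

# Three confident windows (votes 0, 0, 1) and one unconfident window
probs = [[0.8, 0.1, 0.1], [0.75, 0.2, 0.05], [0.2, 0.75, 0.05], [0.5, 0.3, 0.2]]
print(vote_over_windows(probs))
```

Here the fourth window stays below 0.7 and does not vote, so class 0 wins two votes to one.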

### Paper results
- Field/no-field classifier: "Finally, training a CNN to classify field/not field improved precision to 0.98 and lowered recall slightly to 0.95 — a worthwhile trade-off"
- "Maize had the lowest F1 score under both WebCC and iNaturalist (62% and 68%, respectively), followed by sugarcane (62% and 76%). The low performance is likely due to the more visual similarities between the two crops at the early-growth stage as well as their similar height. Meanwhile, rice has short and bright green leaves and cassava has a palmate leaf structure and is grown more separately, both more visually distinctive from the street".
    - By contrast, our task requires telling conditions apart from subtle differences in leaf color and leaf amount, across very different tree species.
- "The top performing model trained on the Expert labeled dataset achieved 82% accuracy when directly classifying the whole image, 90% with sliding windows but no MHP threshold, and 93% with a 0.90 MHP threshold"

### Useful references
+ Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
+ Croptypes paper: https://arxiv.org/pdf/2309.05930.pdf
+ Input needed for resnet50: https://pytorch.org/hub/pytorch_vision_resnet/


In [59]:
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.transforms import v2 as tranforms_v2
from torch.utils.data import DataLoader, random_split, Subset
import torchvision.models as models
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import numpy as np
from collections import Counter
from sklearn.metrics import classification_report, confusion_matrix, f1_score
from torch.utils.tensorboard import SummaryWriter

# Additional Setup to use Tensorboard
%load_ext tensorboard

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


In [9]:
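# Helper function: count how many samples each class has
def get_class_distribution(dataset):
    count_dict = dict()
    # Note: this loads and transforms every image just to read its label
    for _, label in dataset:
        count_dict[label] = count_dict.get(label, 0) + 1
    return count_dict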
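Iterating the dataset decodes and transforms every image only to read its label. A faster sketch reads the labels directly: torchvision's `ImageFolder` stores them in `.targets`, and the `Subset` objects returned by `Subset(...)` and `random_split(...)` expose `.dataset` and `.indices`. The helper name `get_class_distribution_fast` is hypothetical, and the `SimpleNamespace` objects are only stand-ins for the real datasets.

```python
from collections import Counter
from types import SimpleNamespace

def get_class_distribution_fast(dataset):
    """Count labels without loading any images (hypothetical helper).

    Unwraps (possibly nested) Subset-like objects with `.dataset` and
    `.indices` down to a base dataset that has a `.targets` list.
    """
    indices = None
    while hasattr(dataset, "indices"):
        # Map the current index list through this Subset level
        level = list(dataset.indices)
        indices = level if indices is None else [level[i] for i in indices]
        dataset = dataset.dataset
    targets = dataset.targets
    if indices is None:
        indices = range(len(targets))
    return dict(Counter(int(targets[i]) for i in indices))

# Stand-ins for an ImageFolder and two nested Subsets (illustration only)
base = SimpleNamespace(targets=[0, 0, 1, 2, 2, 2])
inner = SimpleNamespace(dataset=base, indices=[0, 2, 3, 4])
outer = SimpleNamespace(dataset=inner, indices=[1, 2, 3])
print(get_class_distribution_fast(outer))
```

With the real objects, `get_class_distribution_fast(train_dataset)` should give the same counts as the slow helper, because `random_split` wraps `balanced_dataset`, which in turn wraps the `ImageFolder`.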


Use ordered_trees_classified.zip to get the folder order this code expects. The images are cropped so that exactly 9 windows of 224x224 pixels fit, with a horizontal stride of 50 pixels and a vertical stride of 36 pixels. The vertical stride is smaller because the image is only 300 pixels high: three rows of 224-pixel windows fit only with a stride of at most (300 - 224) / 2 = 38 pixels, while two rows would need a much larger stride.
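The window geometry can be checked with a little arithmetic. This sketch assumes the 300x324 (height x width) center crop and 224x224 windows used in the code below; `window_offsets` is a hypothetical helper that lists every start coordinate at which a window still fits entirely inside the image.

```python
def window_offsets(length, window, stride):
    """Start coordinates of windows of size `window` that fit fully inside `length` pixels."""
    return list(range(0, length - window + 1, stride))

# Image after CenterCrop((300, 324)): height 300, width 324; windows of 224x224
tops = window_offsets(300, 224, 36)
lefts = window_offsets(324, 224, 50)
print(tops, lefts, len(tops) * len(lefts))
```

This gives top offsets `[0, 36, 72]` and left offsets `[0, 50, 100]`, so the last column ends exactly at pixel 324 and the last row at pixel 296, for 3 x 3 = 9 windows.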

Because of this crop size, each image can be randomly cropped again (to 224x224) during training, and that random crop still falls inside the region covered by the test-time windows.

In [40]:
transform = transforms.Compose([
    transforms.CenterCrop((300, 324)),  # Center crop to 300 (height) x 324 (width); cut into 224x224 windows later
    transforms.ToTensor(),           # Convert the image to a PyTorch tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalize the image
])
data_path = "trees_classified"
# Tree condition labels (Dutch): moderate, fair, poor, good, dead, very poor
label_transform = {0:"Matig", 1:"Redelijk", 2:"Slecht", 3:"Goed", 4:"Dood", 5:"Zeer-Slecht"}  # Folders ordered by sample count, largest class first

In [41]:
dataset = torchvision.datasets.ImageFolder(root=data_path, transform=transform)

In [42]:
# See how unbalanced the classes are before balancing
get_class_distribution(dataset)

{0: 9553, 1: 8104, 2: 947, 3: 289, 4: 116, 5: 87}

In [43]:
# Balance the classes by randomly sampling at most class_limit samples from each class

class_limit = 1000
num_classes = 3  # Number of classes used; this sets the output dimension of the model.

# Only classes 0-2 are used (classes 3-5 have at most 289 samples each)
idx_class0 = [i for i, label in enumerate(dataset.targets) if label == 0]
idx_class1 = [i for i, label in enumerate(dataset.targets) if label == 1]
idx_class2 = [i for i, label in enumerate(dataset.targets) if label == 2]

# Undersample the two large classes; class 2 (947 samples) is kept whole
np.random.shuffle(idx_class0)
np.random.shuffle(idx_class1)
idx_class0_limit = idx_class0[:class_limit]
idx_class1_limit = idx_class1[:class_limit]

idx_dataset_limited = np.concatenate((idx_class0_limit, idx_class1_limit, idx_class2))

balanced_dataset = Subset(dataset, idx_dataset_limited)
get_class_distribution(balanced_dataset)

{0: 1000, 1: 1000, 2: 947}
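The same capping can be written once for any set of classes. This is a sketch, assuming a seeded NumPy generator is acceptable so the selection is reproducible between runs (the cell above uses the unseeded global RNG). `cap_per_class` is a hypothetical helper, and the toy `targets` list is made up for illustration.

```python
import numpy as np

def cap_per_class(targets, limit, classes, seed=0):
    """Indices with at most `limit` randomly chosen samples per class (hypothetical helper).

    Classes with fewer than `limit` samples are kept whole; classes not listed
    in `classes` are dropped.
    """
    rng = np.random.default_rng(seed)
    targets = np.asarray(targets)
    selected = []
    for c in classes:
        idx = np.flatnonzero(targets == c)
        if len(idx) > limit:
            idx = rng.choice(idx, size=limit, replace=False)  # sample without replacement
        selected.append(idx)
    return np.concatenate(selected)

# Toy labels: class 0 has 5 samples, class 1 has 3, class 2 has 2, class 3 has 4
targets = [0] * 5 + [1] * 3 + [2] * 2 + [3] * 4
idx = cap_per_class(targets, limit=3, classes=[0, 1, 2])
print(np.bincount(np.asarray(targets)[idx], minlength=4))
```

On the real data, `Subset(dataset, cap_per_class(dataset.targets, class_limit, range(num_classes)))` would play the role of `balanced_dataset`.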

In [71]:
train_size = int(0.8 * len(balanced_dataset))
test_size = len(balanced_dataset) - train_size
train_dataset, test_dataset = random_split(balanced_dataset, [train_size, test_size])

# Hold out part of the training set for validation, to select the best weights during training.
train_size = int(0.9 * len(train_dataset))
val_size = len(train_dataset) - train_size
train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])
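The two nested splits give roughly 72% training, 8% validation and 20% test data, compared with 60% training data in the paper. The arithmetic below reproduces the set sizes printed in the next cell for the 2947 balanced samples. If reproducible splits are wanted, `random_split` also accepts a `generator` argument, e.g. `torch.Generator().manual_seed(42)`.

```python
n = 2947                         # size of balanced_dataset (1000 + 1000 + 947)
n_train_val = int(0.8 * n)       # first split: train + val
n_test = n - n_train_val         # remainder is the test set
n_train = int(0.9 * n_train_val) # second split: train
n_val = n_train_val - n_train    # remainder is the validation set
print(n_train, n_val, n_test)
print(round(n_train / n, 2), round(n_val / n, 2), round(n_test / n, 2))
```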

In [72]:
# Evaluate class balance per set
print(f"Overall dataset balance {get_class_distribution(balanced_dataset)}")
print(f"Train dataset balance {get_class_distribution(train_dataset)}")
print(f"Test dataset balance {get_class_distribution(test_dataset)}")
print(f"Val dataset balance {get_class_distribution(val_dataset)}")

Overall dataset balance {0: 1000, 1: 1000, 2: 947}
Train dataset balance {1: 730, 2: 665, 0: 726}
Test dataset balance {2: 211, 0: 189, 1: 190}
Val dataset balance {2: 71, 1: 80, 0: 85}


In [73]:
batch_size = 32
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
# No need to shuffle evaluation data
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)
val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)

In [119]:
# Testing transforms and sizes
random_crop = transforms.RandomCrop((224, 224))

for data in train_loader:
    samples, labels = data
    print("Minibatch all at once =", random_crop(samples).size())
    for sample in samples:
        print("One image before and after")
        print("Original size =", sample.size())
        # crop(img, top, left, height, width); a box reaching past the image is zero-padded, not rejected
        sample_trans = tranforms_v2.functional.crop(sample, 0, 100, 220, 220)
        print("transformed size =", sample_trans.size())
        break
    break

# torch.mode returns only the lowest of several tied modes, so ties are detected with torch.unique instead.
unique_values, counts = torch.unique(torch.tensor([1, 0, 0, 2, 2, 2]), return_counts=True)
max_freq = torch.max(counts)
modes = unique_values[counts == max_freq]
if len(modes) > 1:
    print("\nThe input tensor has multiple modes:", modes)
else:
    print("\nThe input tensor has a single mode:", modes.item())

Minibatch all at once = torch.Size([32, 3, 224, 224])
One image before and after
Original size = torch.Size([3, 300, 324])
transformed size = torch.Size([3, 220, 220])

The input tensor has a single mode: 2


In [35]:
resnet50 = models.resnet50(weights='DEFAULT')


In [36]:
random_crop = transforms.RandomCrop((224, 224))

def train(train_loader, net, optimizer, criterion, device):
    """
    Trains network for one epoch in batches.

    Args:
        train_loader: Data loader for training set.
        net: Neural network model.
        optimizer: Optimizer (e.g. SGD).
        criterion: Loss function (e.g. cross-entropy loss).
        device: Device to run on (CPU or CUDA).

    Returns:
        Average loss per batch and accuracy (%) over the epoch.
    """

    avg_loss = 0.
    correct = 0
    total = 0

    net.train()

    # iterate through batches
    for i, data in enumerate(train_loader):
        # Random 224x224 crop per image as augmentation (one call on the whole
        # batch would apply the same crop to every image)
        inputs = torch.stack([random_crop(img) for img in data[0]])
        labels = data[1]
        inputs, labels = inputs.to(device), labels.to(device)

        # zero the gradients of the optimizer
        optimizer.zero_grad()

        # forward pass, backward pass, optimizer step
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # keep track of loss and accuracy (.item() avoids keeping the autograd graph alive)
        avg_loss += loss.item()
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    return avg_loss/len(train_loader), 100 * correct/total

def test(test_loader, net, criterion, device):
    """
    Evaluates the network with a 3x3 sliding window and a majority vote per image.

    Images with a tied vote are skipped, so loss and accuracy cover only
    images with a unique majority class.
    """
    avg_loss = 0.
    correct = 0
    total = 0
    labels_pred = []
    labels_true = []
    net.eval()

    with torch.no_grad():
        for data in test_loader:
            inputs, labels = data

            # Apply sliding window on every image of the batch
            for image, label in zip(inputs, labels):
                image_windows = []
                for t in [0, 36, 72]:  # top pixel coordinate (vertical stride 36)
                    for l in [0, 50, 100]:  # left pixel coordinate (horizontal stride 50, width 324)
                        image_windows.append(tranforms_v2.functional.crop(image, t, l, 224, 224))
                window_tensor = torch.stack(image_windows)
                # Same integer label for every window
                labels_tensor = torch.full((len(image_windows),), label.item(), dtype=torch.long)

                window_tensor, labels_tensor = window_tensor.to(device), labels_tensor.to(device)

                outputs = net(window_tensor)
                predicted = outputs.argmax(dim=1)

                # Calculate modes and their occurrence (torch.mode only gives the lowest mode)
                unique_values, counts = torch.unique(predicted, return_counts=True)
                max_freq = torch.max(counts)
                modes = unique_values[counts == max_freq]

                # Only a single mode gives a majority vote and a valid result
                if len(modes) == 1:
                    aggregated_label = modes.item()

                    # Loss averaged over the windows, later averaged over images
                    loss = criterion(outputs, labels_tensor)
                    avg_loss += loss.item()  # per image, so divide by total (not batch count) at the end

                    total += 1
                    correct += int(aggregated_label == label.item())
                    labels_pred.append(aggregated_label)
                    labels_true.append(label.item())

    if total == 0:  # every image had a tied vote
        return float('nan'), float('nan'), labels_pred, labels_true
    return avg_loss/total, 100 * correct/total, labels_pred, labels_true

In [37]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

In [38]:
import copy

writer = SummaryWriter(log_dir="./runs_condition/", flush_secs=300)

# Freeze the pretrained backbone; only the new classification head is trained
for param in resnet50.parameters():
    param.requires_grad = False
resnet50.fc = nn.Linear(resnet50.fc.in_features, num_classes)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(resnet50.fc.parameters(), lr=0.001)

resnet50.to(device)

# Maximum number of epochs for training
epochs = 100

# Early stopping: stop after `patience` epochs without validation improvement
patience = 15

val_acc_best = 0
patience_cnt = 0
best_model_wts = copy.deepcopy(resnet50.state_dict())

for epoch in tqdm(range(epochs)):  # loop over the dataset multiple times
    # Train on data
    train_loss, train_acc = train(train_loader, resnet50, optimizer, criterion, device)

    # Evaluate on validation data
    val_loss, val_acc, _, _ = test(val_loader, resnet50, criterion, device)

    # Write metrics to Tensorboard
    writer.add_scalars("Loss", {'Train': train_loss, 'Val': val_loss}, epoch)
    writer.add_scalars('Accuracy', {'Train': train_acc, 'Val': val_acc}, epoch)

    if val_acc > val_acc_best:
        val_acc_best = val_acc
        patience_cnt = 0
        # deepcopy: state_dict() returns references that later training steps overwrite
        best_model_wts = copy.deepcopy(resnet50.state_dict())
    else:
        patience_cnt += 1
        if patience_cnt == patience:
            break


print('Finished Training')
writer.flush()
writer.close()

  0%|                                                                                           | 0/15 [02:20<?, ?it/s]


KeyboardInterrupt: 
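The paper notes list an initial learning rate of 0.001 followed by cosine annealing, while the training cell above keeps the SGD learning rate fixed. The sketch below only computes the schedule's values, assuming annealing to 0 over 15 epochs with no restarts; `cosine_annealing_lr` is a hypothetical helper implementing $\eta_t = \eta_{min} + \tfrac{1}{2}(\eta_{max} - \eta_{min})(1 + \cos(\pi t / T))$. In PyTorch, `torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=15)`, with `scheduler.step()` called once per epoch after training, should follow the same curve.

```python
import math

def cosine_annealing_lr(epoch, base_lr=0.001, total_epochs=15, min_lr=0.0):
    """Learning rate at `epoch` under cosine annealing without restarts (hypothetical helper)."""
    cos_term = (1 + math.cos(math.pi * epoch / total_epochs)) / 2
    return min_lr + (base_lr - min_lr) * cos_term

schedule = [cosine_annealing_lr(e) for e in range(15)]
print([f"{lr:.6f}" for lr in schedule])
```

The rate starts at 0.001, is halved at the midpoint (epoch 7.5) and decays towards 0 at epoch 15.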

In [None]:
# Open Tensorboard
%tensorboard --logdir runs_condition/

# For local users only: uncomment the last line, run this cell once and wait for
# it to time out, run this cell a second time and you should see the board.
# %tensorboard --logdir runs_condition/ --host localhost

In [None]:
resnet50.load_state_dict(best_model_wts)

In [None]:
torch.save(resnet50.state_dict(), "resnet50_model.pt")

In [None]:
resnet50.eval()

# Evaluate on the held-out test data (sliding window + majority vote)
test_loss, test_acc, predicted_labels, true_labels = test(test_loader, resnet50, criterion, device)

# Print the evaluation metrics
print(classification_report(true_labels, predicted_labels))

print(confusion_matrix(true_labels, predicted_labels))

print(f1_score(true_labels, predicted_labels, average=None))  # Per-class F1

print(f1_score(true_labels, predicted_labels, average='micro'))  # Global (micro) F1; equals accuracy for single-label multiclass

print(f"Accuracy: {test_acc}, Average loss: {test_loss}")
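To see what the per-class and micro-averaged scores measure, here is a small NumPy sketch that computes them directly from a confusion matrix with rows as true classes and columns as predicted classes (the layout of `sklearn.metrics.confusion_matrix`). `f1_from_confusion` is a hypothetical helper and the matrix values are made up for illustration. It also shows why micro F1 equals accuracy when every sample has exactly one label.

```python
import numpy as np

def f1_from_confusion(cm):
    """Per-class and micro-averaged F1 from a confusion matrix (hypothetical helper)."""
    cm = np.asarray(cm, dtype=float)
    tp = np.diag(cm)
    pred_totals = cm.sum(axis=0)  # samples predicted as each class
    true_totals = cm.sum(axis=1)  # samples truly in each class
    precision = np.divide(tp, pred_totals, out=np.zeros_like(tp), where=pred_totals > 0)
    recall = np.divide(tp, true_totals, out=np.zeros_like(tp), where=true_totals > 0)
    denom = precision + recall
    per_class = np.divide(2 * precision * recall, denom, out=np.zeros_like(tp), where=denom > 0)
    micro = tp.sum() / cm.sum()  # micro precision = micro recall = accuracy
    return per_class, micro

cm = [[8, 2, 0],
      [1, 7, 2],
      [0, 1, 9]]
per_class, micro = f1_from_confusion(cm)
print(per_class, micro)
```

For this matrix the per-class F1 scores are 16/19, 0.7 and 6/7, and the micro F1 is 24/30 = 0.8, the fraction of samples on the diagonal.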