In [1]:
import numpy as np
import pandas as pd
import os
import random
import time
import math
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import gc

import torch
import torchvision
import torch.nn as nn
import torchinfo as info
import torchvision.transforms as transforms
import torchvision.transforms.functional as functional
from torchvision.transforms import v2

In [2]:
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
set_seed(500)

In [3]:
wd = os.getcwd()
device = torch.device('cuda')

## Model

In [4]:
from torchvision.models import convnext_base, ConvNeXt_Base_Weights

In [5]:
model = convnext_base(weights = ConvNeXt_Base_Weights.IMAGENET1K_V1).to(device)
process = ConvNeXt_Base_Weights.IMAGENET1K_V1.transforms()

In [6]:
for param in model.parameters():
    param.requires_grad = False
layers = []
layers.append(nn.Linear(in_features = 1024, out_features = 256, bias = True))
layers.append(nn.ReLU())
layers.append(nn.Linear(in_features = 256, out_features = 9, bias = True))
model.classifier[2] = nn.Sequential(*layers).to(device)

## Load and Process Data

In [7]:
class WeedsDataset():
    def __init__(self, labels, img_dir, transform = None):
        self.labels = pd.read_csv(labels)
        self.img_dir = img_dir
        self.transform = transform
        self.classes = self.labels[['Label','Species']].drop_duplicates().sort_values(by = 'Label').reset_index(drop = True)['Species']
        self.data = []
        self.train = []
        self.validation = []
        self.test = []
        
    # Load entire dataset into memory
    def __loaddata__(self):
        for row in self.labels.itertuples():
            filename = row.Filename
            label = torch.tensor(row.Label)
            img_path = os.path.join(self.img_dir, filename)
            image = torchvision.io.read_image(img_path)
            self.data.append([image,label])
        del self.labels
        print('Data has been loaded')

    # Split into train/validation/test
    def __split__(self):
        random.shuffle(self.data)
        n = len(self.data)
        n_train = round(n * 0.8)
        n_valid = round(n * 0.1)
        
        self.train = self.data[:n_train]
        self.valid = self.data[n_train:n_train + n_valid]
        self.test = self.data[n_train + n_valid:]
        del self.data
        gc.collect()
        
        print('Data has been split')
        print('Size of training set is {}'.format(len(self.train)))
        print('Size of validation set is {}'.format(len(self.valid)))
        print('Size of test set is {}'.format(len(self.test)))

    # Augment training data
    def __augment__(self, frac):
        aug = []
        color_jitter = transforms.ColorJitter(brightness = 0.3, contrast = 0.2, saturation = 0.3, hue = 0.0)
        n = round(len(self.train) * frac / 3)
        
        rotate_90 = random.sample(self.train, n)
        for row in rotate_90:
            image = row[0].clone()
            image = transforms.functional.rotate(image, 90)
            image = color_jitter(image)
            row[0] = image
        
        rotate_180 = random.sample(self.train, n)
        for row in rotate_180:
            image = row[0].clone()
            image = transforms.functional.rotate(image, 180)
            image = color_jitter(image)
            row[0] = image
        
        rotate_270 = random.sample(self.train, n)
        for row in rotate_270:
            image = row[0].clone()
            image = transforms.functional.rotate(image, 270)
            image = color_jitter(image)
            row[0] = image
     
        self.train += rotate_90 + rotate_180 + rotate_270  
        print('Data has been augmented')
        print('Size of training set is {}'.format(len(self.train)))

    # Apply model transformation
    def __apply__(self):
        for i in range(len(self.train)):
            self.train[i][0] = self.transform(self.train[i][0])
        for i in range(len(self.valid)):
            self.valid[i][0] = self.transform(self.valid[i][0])
        for i in range(len(self.test)):
            self.test[i][0] = self.transform(self.test[i][0]) 
        print('Data has been transformed')

In [8]:
# Run functions to load and process data
weeds = WeedsDataset(rf'{wd}\labels.csv', rf'{wd}\images', process)
weeds.__loaddata__()
weeds.__split__()
weeds.__augment__(0.6)
weeds.__apply__()

Data has been loaded
Data has been split
Size of training set is 14007
Size of validation set is 1751
Size of test set is 1751
Data has been augmented
Size of training set is 22410
Data has been transformed


## Functions

In [9]:
def get_batch(dataset, current_index, batch_size):
    images,labels = zip(*dataset[current_index:current_index+batch_size])
    return torch.stack(images, dim = 0).to(device), torch.stack(labels).to(device)
    
def train_one_epoch(model, dataset, batch_size):
    batches = math.floor(len(dataset.train)/batch_size)
    current_index = 0
    running_loss = 0
    random.shuffle(dataset.train)

    for i in range(batches):
        inputs,labels = get_batch(dataset.train, current_index, batch_size)
        current_index += batch_size
        optimizer.zero_grad(set_to_none=True)
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.cpu().item()

    remainder = len(dataset.train)%batch_size
    if remainder != 0:
        inputs,labels = get_batch(dataset.train, current_index, batch_size)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.cpu().item()
        batches += 1
    
    return running_loss/batches

def validate(model, dataset, batch_size):
    batches = math.floor(len(dataset.valid)/batch_size)
    current_index = 0
    running_loss = 0
    
    for i in range(batches):
        inputs,labels = get_batch(dataset.valid, current_index, batch_size)
        current_index += batch_size
        output = model(inputs)
        running_loss += loss_fn(output, labels).cpu().item()

    remainder = len(dataset.valid)%batch_size
    if remainder != 0:
        inputs,labels = get_batch(dataset.valid, current_index, batch_size)
        output = model(inputs)
        running_loss += loss_fn(output, labels).cpu().item()
        batches += 1
    
    return running_loss/batches

def train(epochs, model, dataset, batch_size, folder_name, epoch_number, save):
    best_vloss = 1000
    current_vloss = 0
    epoch_counter = epoch_number + 1
    start = time.perf_counter()
    
    if not os.path.exists('{}\{}'.format(wd,folder_name)):
        os.makedirs('{}\{}'.format(wd,folder_name))
        
    for epoch in range(epochs):
        print('EPOCH {}:'.format(epoch_counter))
        model.train()
        avg_loss = train_one_epoch(model, dataset, batch_size)
        
        model.eval()
        with torch.no_grad():
            current_vloss = validate(model, dataset, batch_size)

        print('LOSS train {} valid {}'.format(avg_loss, current_vloss))
        if current_vloss <= best_vloss:
                best_vloss = current_vloss
                print('New best validation loss')

        print('Lr is currently {}'.format(lr_scheduler.get_last_lr()))
        lr_scheduler.step(current_vloss)

        if epoch_counter > epochs - save:
            model_path = '{}\model_{}'.format(folder_name, epoch_counter)
            torch.save(model.state_dict(), model_path)
            
        epoch_counter += 1

    end = time.perf_counter()
    print('Average time taken per epoch is {}s'.format((end - start)/epochs))
    return

In [10]:
# Formula for F1: (2 * FP) / (2 * TP + FP + FN) OR 2 * (Precision * Recall) / (Precision + Recall)
# Formula for Precision: TP / (TP + FP)
# Formula for Recall: TP / (TP + FN)
def metrics(TP,FP,FN):
    if TP > 0:
        precision = TP / (TP + FP)
        recall = TP / (TP + FN)
        F1  = TP / (TP + 0.5 * (FP + FN))
    else:
        precision = 0
        recall = 0
        F1  = 0
    return precision, recall, F1
    
def evaluate(model, dataset, batch_size):
    start = time.perf_counter()
    total = 0
    correct = 0
    batches = math.floor(len(dataset.test) / batch_size)
    current_index = 0
    model.eval()
    
    all_preds = []
    all_labels = []

    for i in range(batches):
        inputs, labels = get_batch(dataset.test, current_index, batch_size)
        current_index += batch_size
        prediction = torch.argmax(model(inputs), dim=1)
        
        correct += sum(prediction == labels).item()
        total += batch_size

        all_preds.extend(prediction.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

    remainder = len(dataset.test) % batch_size
    if remainder != 0:
        inputs, labels = get_batch(dataset.test, current_index, remainder)
        prediction = torch.argmax(model(inputs), dim=1)
        
        correct += sum(prediction == labels).item()
        total += remainder

        all_preds.extend(prediction.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

    accuracy = 100 * correct / total

    # Compute Metrics for individual classes
    num_classes = len(dataset.classes) 
    conf_matrix = confusion_matrix(all_labels, all_preds, labels=np.arange(num_classes))
    print(conf_matrix)

    for i in range(num_classes):
        TP = conf_matrix[i, i]  # True Positives
        FP = np.sum(conf_matrix[:, i]) - TP  # False Positives
        FN = np.sum(conf_matrix[i, :]) - TP  # False Negatives
        precision,recall,F1 = metrics(TP, FP, FN)
        print(f"Precision for class {dataset.classes[i]}: {precision:.4f}")
        print(f"Recall for class {dataset.classes[i]}: {recall:.4f}")
        print(f"F1-score for class {dataset.classes[i]}: {F1:.4f}")
    print('Total Accuracy is {}%'.format(accuracy))
    
    end = time.perf_counter()
    print('Time taken is {}'.format(end - start))

## Training with only classifier unfrozen

In [11]:
loss_fn = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr = 0.01, momentum = 0.6)
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor = 0.5, patience = 15)

In [12]:
train(100,model,weeds,50,'convnext',0,50)

EPOCH 1:
LOSS train 1.3151770301810353 valid 0.9443850848409865
New best validation loss
Lr is currently [0.01]
EPOCH 2:
LOSS train 0.9444803911487349 valid 0.7382359165284369
New best validation loss
Lr is currently [0.01]
EPOCH 3:
LOSS train 0.8059539428532522 valid 0.6703195091750886
New best validation loss
Lr is currently [0.01]
EPOCH 4:
LOSS train 0.7222074720933868 valid 0.5721244696113799
New best validation loss
Lr is currently [0.01]
EPOCH 5:
LOSS train 0.672411685133299 valid 0.5119041088554594
New best validation loss
Lr is currently [0.01]
EPOCH 6:
LOSS train 0.6272963405584705 valid 0.464963318573104
New best validation loss
Lr is currently [0.01]
EPOCH 7:
LOSS train 0.6002127816921353 valid 0.4481567301683956
New best validation loss
Lr is currently [0.01]
EPOCH 8:
LOSS train 0.5780853566056104 valid 0.41632111246387166
New best validation loss
Lr is currently [0.01]
EPOCH 9:
LOSS train 0.5549104350114984 valid 0.40299441417058307
New best validation loss
Lr is currently

## Evaluation

In [15]:
model_path = os.path.join(wd,'convnext','model_84')
model.load_state_dict(torch.load(model_path, weights_only = True))

<All keys matched successfully>

In [16]:
with torch.no_grad():
    evaluate(model, weeds, 50)

[[ 83   1   0   1   1   1   1   5  13]
 [  4  88   0   0   0   0   1   1   6]
 [  0   0  95   0   2   0   0   0   1]
 [  1   1   3  78   0   0   1   2  11]
 [  1   0   1   2 108   0   0   0   4]
 [  1   0   0   0   0  93   0   0   8]
 [  0   1   0   0   0   0  85   1   7]
 [  9   2   0   0   1   1   0  76   3]
 [  6   4   2   2   7   5   9   8 903]]
Precision for class Chinee apple: 0.7905
Recall for class Chinee apple: 0.7830
F1-score for class Chinee apple: 0.7867
Precision for class Lantana: 0.9072
Recall for class Lantana: 0.8800
F1-score for class Lantana: 0.8934
Precision for class Parkinsonia: 0.9406
Recall for class Parkinsonia: 0.9694
F1-score for class Parkinsonia: 0.9548
Precision for class Parthenium: 0.9398
Recall for class Parthenium: 0.8041
F1-score for class Parthenium: 0.8667
Precision for class Prickly acacia: 0.9076
Recall for class Prickly acacia: 0.9310
F1-score for class Prickly acacia: 0.9191
Precision for class Rubber vine: 0.9300
Recall for class Rubber vine: 0