In [1]:
import numpy as np
import pandas as pd
import os
import random
import time
import math
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.metrics import confusion_matrix
import csv

import torch
import torchvision
import torch.nn as nn
import torchinfo as info
import torchvision.transforms.v2 as v2

In [2]:
def set_seed(seed):
    """Seed Python's ``random``, NumPy and torch (CPU and all CUDA devices).

    ``random`` drives the per-epoch shuffle in ``train_one_epoch``, so it must be
    seeded together with the numeric libraries. cuDNN determinism flags are not
    touched here.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # manual_seed_all seeds every visible GPU (cuda.manual_seed only the current one);
    # per the torch docs it is safe to call when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)


set_seed(600)
# Fall back to the CPU so the notebook still runs on a machine without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
# All inputs, logs and checkpoints live in sub-folders of the working directory.
wd = os.getcwd()
labels_dir, img_dir, record_path, model_path = (
    os.path.join(wd, folder) for folder in ('Labels', 'Sets', 'Records', 'Models')
)
# Only the log folder is created here; train() creates its checkpoint sub-folders itself.
if not os.path.exists(record_path):
    os.makedirs(record_path)

## Model

In [4]:
from torchvision.models import swin_s, Swin_S_Weights

In [5]:
# ImageNet-1k pretrained Swin-S; `transform` is the preprocessing bundled with these
# weights. Every parameter starts frozen and selected layers are re-enabled below.
model = swin_s(weights=Swin_S_Weights.IMAGENET1K_V1).to(device)
model.requires_grad_(False)
transform = Swin_S_Weights.IMAGENET1K_V1.transforms()

In [6]:
# Replace the classifier with a small MLP ending in 9 outputs (one per weed class).
# The first Linear's in_features=768 must match the width of the backbone features.
layers = [
    nn.Linear(in_features=768, out_features=256, bias=True),
    nn.Dropout(p=0.15),
    nn.ReLU(),
    nn.Linear(in_features=256, out_features=256, bias=True),
    nn.ReLU(),
    nn.Linear(in_features=256, out_features=9, bias=True),
]
model.head = nn.Sequential(*layers).to(device)

In [9]:
# Fine-tune blocks 0 and 1 of model.features[7] plus the final norm; the rest of the
# backbone stays frozen. The new head is trainable because it was created fresh.
for trainable in (model.features[7][0], model.features[7][1], model.norm):
    trainable.requires_grad_(True)

In [None]:
#info.summary(model)

## Dataset

In [11]:
class WeedsDataset():
    """In-memory train/valid/test splits of the weeds image dataset.

    Label CSVs are read from ``<labels_dir>/{train,valid,test}.csv`` and must have
    ``Filename``, ``Label`` and ``Class`` columns; images are read from
    ``<img_dir>/<split>/<Filename>``. Usage: construct, call ``__loaddata__()``
    once to read the raw images, then ``__apply__()`` once to run ``transform``
    and convert labels to tensors. Each split is a list of ``[image, label]`` pairs.
    """

    def __init__(self, labels_dir, img_dir, transform):
        """Read the three label CSVs; images are not read until ``__loaddata__``."""
        self.train_labels = pd.read_csv(os.path.join(labels_dir, 'train.csv'))
        self.valid_labels = pd.read_csv(os.path.join(labels_dir, 'valid.csv'))
        self.test_labels = pd.read_csv(os.path.join(labels_dir, 'test.csv'))
        self.img_dir = img_dir
        self.classes = None  # set by __loaddata__: class names indexed by label id
        self.counts = {'train': {}, 'valid': {}, 'test': {}}  # per split: {label: image count}
        self.train = []
        self.valid = []
        self.test = []
        self.augmented = []  # kept for compatibility; nothing in this class fills it
        self.transform = transform

    def _load_split(self, labels, split):
        """Read the images listed in ``labels`` from ``img_dir/split`` as ``[image, label]`` pairs.

        Also increments ``self.counts[split]`` once per row.
        """
        samples = []
        for row in labels.itertuples():
            self.counts[split][row.Label] += 1
            img_path = os.path.join(self.img_dir, split, row.Filename)
            samples.append([torchvision.io.read_image(img_path), row.Label])
        return samples

    # Load dataset
    def __loaddata__(self):
        """Build ``classes``/``counts`` and read every image of every split into memory.

        The label DataFrames are deleted afterwards to free memory, so this can
        only be called once per instance.
        """
        self.classes = (
            self.train_labels[['Label', 'Class']]
            .drop_duplicates()
            .sort_values(by='Label')
            .reset_index(drop=True)['Class']
        )
        # Counts are keyed by the positions 0..n_classes-1 of `classes`; a Label
        # outside that range raises KeyError in _load_split.
        for key in self.classes.keys():
            for split in ('train', 'valid', 'test'):
                self.counts[split][key] = 0

        self.train.extend(self._load_split(self.train_labels, 'train'))
        self.valid.extend(self._load_split(self.valid_labels, 'valid'))
        self.test.extend(self._load_split(self.test_labels, 'test'))

        del self.train_labels, self.valid_labels, self.test_labels
        print('Data has been loaded')

    def _apply_split(self, samples):
        """Transform each image and tensorize each label of ``samples`` in place."""
        for i, (image, label) in enumerate(samples):
            # Use the instance's transform, not whatever global happens to be named `transform`.
            samples[i] = [self.transform(image), torch.tensor(label)]

    def __apply__(self):
        """Run ``self.transform`` on every image and convert every label to a tensor.

        Call exactly once after ``__loaddata__``: a second call would transform
        already-transformed images. The split lists are updated in place.
        """
        for samples in (self.train, self.valid, self.test):
            self._apply_split(samples)
        print('Data has been processed')

In [12]:
# Read label CSVs, load every image into memory, then apply the backbone's preprocessing.
weeds = WeedsDataset(labels_dir=labels_dir, img_dir=img_dir, transform=transform)
weeds.__loaddata__()
weeds.__apply__()

Data has been loaded
Data has been processed


In [13]:
#print(weeds.counts)
#print(weeds.classes)

## Functions

In [14]:
def get_batch(dataset, current_index, batch):
    """Collate up to ``batch`` consecutive samples starting at ``current_index``.

    ``dataset`` is a list of ``[image, label]`` tensor pairs. Returns the images
    stacked along a new dim 0 and the labels stacked into one tensor, both moved
    to the module-level ``device``. Python slicing truncates at the end of the
    list, so the last batch may hold fewer than ``batch`` samples.
    """
    window = dataset[current_index:current_index + batch]
    images, labels = zip(*window)
    image_batch = torch.stack(images, dim=0).to(device)
    label_batch = torch.stack(labels).to(device)
    return image_batch, label_batch
    
def train_one_epoch(model, dataset, batch_size):
    """Run one shuffled pass over ``dataset.train`` and return the mean per-batch loss.

    Uses the module-level ``optimizer`` and ``loss_fn``. ``dataset.train`` is
    shuffled in place. The final batch may be smaller than ``batch_size`` and
    counts as one batch in the average, like every other batch.
    Returns ``nan`` when ``dataset.train`` is empty.
    """
    random.shuffle(dataset.train)
    running_loss = 0.0
    n_batches = 0

    # Stepping by batch_size lets get_batch's truncating slice yield the final
    # partial batch, so no separate remainder branch is needed.
    for start in range(0, len(dataset.train), batch_size):
        inputs, labels = get_batch(dataset.train, start, batch_size)
        optimizer.zero_grad(set_to_none=True)
        loss = loss_fn(model(inputs), labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        n_batches += 1

    if n_batches == 0:
        return float('nan')
    return running_loss / n_batches

def validate(model, dataset, batch_size):
    """Return the mean per-batch loss of ``model`` over ``dataset.valid``.

    Uses the module-level ``loss_fn``. Runs under ``torch.no_grad()`` so it is
    safe to call on its own; the caller is responsible for ``model.eval()``.
    Returns ``nan`` when ``dataset.valid`` is empty.
    """
    running_loss = 0.0
    n_batches = 0

    with torch.no_grad():
        # get_batch truncates at the end of the list, so the last partial batch is included.
        for start in range(0, len(dataset.valid), batch_size):
            inputs, labels = get_batch(dataset.valid, start, batch_size)
            running_loss += loss_fn(model(inputs), labels).item()
            n_batches += 1

    if n_batches == 0:
        return float('nan')
    return running_loss / n_batches

def train(epochs, model, dataset, batch_size, folder_name, epoch_number, save, record):
    """Train for ``epochs`` epochs, validating, logging and checkpointing after each.

    Args:
        epochs: number of epochs to run in this call.
        model: network to train (the module-level ``optimizer``/``loss_fn`` are used).
        dataset: object with ``train``/``valid`` lists of ``[image, label]`` pairs.
        batch_size: mini-batch size for training and validation.
        folder_name: sub-directory of ``model_path`` that checkpoints are written to.
        epoch_number: epochs already completed; printed epochs, checkpoint names and
            ``record`` rows continue numbering from ``epoch_number + 1``.
        save: the state dict is saved for the last ``save`` epochs of this call.
        record: list receiving one ``[epoch, train_loss, valid_loss]`` row per epoch.

    Returns:
        None.
    """
    best_vloss = float('inf')
    start = time.perf_counter()

    # os.path.join keeps paths portable and matches how checkpoints are loaded
    # elsewhere (os.path.join(model_path, 'swin1', 'model_N')).
    save_dir = os.path.join(model_path, folder_name)
    os.makedirs(save_dir, exist_ok=True)

    for epoch_counter in range(epoch_number + 1, epoch_number + epochs + 1):
        print('EPOCH {}:'.format(epoch_counter))
        model.train()
        avg_loss = train_one_epoch(model, dataset, batch_size)

        model.eval()
        with torch.no_grad():
            current_vloss = validate(model, dataset, batch_size)

        print('LOSS train {} valid {}'.format(avg_loss, current_vloss))
        if current_vloss <= best_vloss:
            best_vloss = current_vloss
            print('New best validation loss')

        # Only the final `save` epochs of this call are written to disk.
        if epoch_counter - epoch_number > epochs - save:
            path = os.path.join(save_dir, 'model_{}'.format(epoch_counter))
            torch.save(model.state_dict(), path)

        record.append([epoch_counter, avg_loss, current_vloss])

    end = time.perf_counter()
    if epochs > 0:
        print('Average time taken per epoch is {}s'.format((end - start) / epochs))
    return

def compute_weights(dataset):
    """Return inverse-frequency class weights (float32) for ``nn.CrossEntropyLoss``.

    weight[c] = n_samples / (n_classes * count[c]): a class with exactly the
    average frequency gets weight 1, and rarer classes get proportionally more.
    Label ids are assumed to be 0..n_classes-1, the keys of ``dataset.counts['train']``.

    Raises:
        ValueError: if some class has no training samples (its weight would be infinite).
    """
    n_samples = len(dataset.train)
    n_classes = len(dataset.classes)
    train_counts = dataset.counts['train']

    missing = [c for c in range(n_classes) if train_counts[c] == 0]
    if missing:
        raise ValueError('classes with no training samples: {}'.format(missing))

    # Divide in float64 and round once to float32, as assigning Python floats did.
    counts = torch.tensor([train_counts[c] for c in range(n_classes)], dtype=torch.float64)
    return (n_samples / (n_classes * counts)).to(torch.float32)

In [15]:
# Formula for F1: (2 * TP) / (2 * TP + FP + FN) OR 2 * (Precision * Recall) / (Precision + Recall)
# Formula for Precision: TP / (TP + FP)
# Formula for Recall: TP / (TP + FN)
def metrics(TP, FP, FN):
    """Return ``(precision, recall, F1)`` for one class from its TP/FP/FN counts.

    F1 is computed as TP / (TP + (FP + FN) / 2), which equals
    2 * TP / (2 * TP + FP + FN). With no true positives all three are reported
    as 0, which also sidesteps the 0/0 cases of a class that is never predicted
    or never present.
    """
    if not TP > 0:
        return 0, 0, 0
    half_errors = 0.5 * (FP + FN)
    return TP / (TP + FP), TP / (TP + FN), TP / (TP + half_errors)
    
# Pass a split list such as weeds.test or weeds.valid as `dataset`.
def evaluate(model, dataset, batch_size, classes):
    """Print the confusion matrix, per-class precision/recall/F1, accuracy and run time.

    Args:
        model: classifier whose output has one row of class scores per input.
        dataset: list of ``[image, label]`` tensor pairs, e.g. ``weeds.test``.
        batch_size: number of samples per forward pass.
        classes: class names indexed by label id, e.g. ``weeds.classes``.
    """
    start = time.perf_counter()
    model.eval()
    correct = 0
    total = 0
    all_preds = []
    all_labels = []

    # no_grad here makes the function safe to call without an outer no_grad block.
    with torch.no_grad():
        # get_batch truncates at the end of the list, so the last partial batch is included.
        for first in range(0, len(dataset), batch_size):
            inputs, labels = get_batch(dataset, first, batch_size)
            prediction = torch.argmax(model(inputs), dim=1)
            # Tensor reduction rather than Python's sum(), which iterates element by element.
            correct += (prediction == labels).sum().item()
            total += len(labels)
            all_preds.extend(prediction.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    if total == 0:
        print('Dataset is empty; nothing to evaluate')
        return

    accuracy = 100 * correct / total

    # Compute Metrics for individual classes (sklearn: rows = true label, columns = prediction)
    num_classes = len(classes)
    conf_matrix = confusion_matrix(all_labels, all_preds, labels=np.arange(num_classes))
    print(conf_matrix)

    for i in range(num_classes):
        TP = conf_matrix[i, i]
        FP = np.sum(conf_matrix[:, i]) - TP  # predicted as i, truly another class
        FN = np.sum(conf_matrix[i, :]) - TP  # truly i, predicted as another class
        precision, recall, F1 = metrics(TP, FP, FN)
        print(f"Precision for class {classes[i]}: {precision:.4f}")
        print(f"Recall for class {classes[i]}: {recall:.4f}")
        print(f"F1-score for class {classes[i]}: {F1:.4f}")
    print('Total Accuracy is {}%'.format(accuracy))

    end = time.perf_counter()
    print('Time taken is {}'.format(end - start))

## Training

In [16]:
# Inverse-frequency class weights derived from the training-set label counts.
weights = compute_weights(dataset=weeds)
print(weights)

tensor([1.7066, 1.8397, 1.8639, 1.8888, 1.8639, 1.9167, 1.8441, 1.8865, 0.2140])


In [17]:
# Loss log appended to by train(): a header row, then one [epoch, train loss, valid loss] row per epoch.
record = [["Epoch", "Training loss", "Validation loss"]]

In [18]:
# Class-weighted cross-entropy; .to(device) moves the weight buffer along with the module.
loss_fn = nn.CrossEntropyLoss(weight=weights).to(device)
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=0.01,
    momentum=0.6,
    weight_decay=1e-5,
)

In [19]:
# First run: 30 epochs from epoch 0, batch size 30, checkpointing all 30 epochs into Models/swin1.
train(epochs=30, model=model, dataset=weeds, batch_size=30, folder_name='swin1',
      epoch_number=0, save=30, record=record)

EPOCH 1:
LOSS train 1.00493153789314 valid 0.5039573592149605
New best validation loss
EPOCH 2:
LOSS train 0.5217553137201313 valid 0.3526537158857968
New best validation loss
EPOCH 3:
LOSS train 0.41513340881919963 valid 0.26030538894109806
New best validation loss
EPOCH 4:
LOSS train 0.36448889929099243 valid 0.23041003247943975
New best validation loss
EPOCH 5:
LOSS train 0.3117625881624618 valid 0.1967321124117253
New best validation loss
EPOCH 6:
LOSS train 0.27295536753907 valid 0.1972112844429784
EPOCH 7:
LOSS train 0.24587740382574483 valid 0.1546279796383391
New best validation loss
EPOCH 8:
LOSS train 0.22237506148662634 valid 0.1616718540346976
EPOCH 9:
LOSS train 0.21099097887951507 valid 0.1537213812234922
New best validation loss
EPOCH 10:
LOSS train 0.19477486335403213 valid 0.14267908309791552
New best validation loss
EPOCH 11:
LOSS train 0.18178888449118308 valid 0.1581050243733798
EPOCH 12:
LOSS train 0.1698375968501865 valid 0.14004929144314285
New best validation lo

In [23]:
# Reload the epoch-22 checkpoint before resuming training with a lower learning rate.
# map_location lets a checkpoint saved from a GPU load when `device` is the CPU.
path = os.path.join(model_path, 'swin1', 'model_22')
model.load_state_dict(torch.load(path, map_location=device, weights_only=True))

<All keys matched successfully>

In [24]:
# Resume after epoch 22: keep the header plus the log rows for epochs <= 22 (dropping
# rows the first run logged later), then train 18 more epochs at a lower learning
# rate, checkpointing every one. Filtering by epoch keeps this correct on re-run.
resume_epoch = 22
extra_epochs = 18
record = [record[0]] + [row for row in record[1:] if row[0] <= resume_epoch]
optimizer = torch.optim.SGD(model.parameters(), lr=0.004, momentum=0.6, weight_decay=0.00001)
train(extra_epochs, model, weeds, 30, 'swin1', resume_epoch, extra_epochs, record)

EPOCH 23:
LOSS train 0.0791582308219281 valid 0.12440417118477083
New best validation loss
EPOCH 24:
LOSS train 0.07315790976624043 valid 0.11695349175222534
New best validation loss
EPOCH 25:
LOSS train 0.07690642252529531 valid 0.1277459480314342
EPOCH 26:
LOSS train 0.07319271830191643 valid 0.12677410944420212
EPOCH 27:
LOSS train 0.06622254528001557 valid 0.13365860286827783
EPOCH 28:
LOSS train 0.06654792815045259 valid 0.12906371196830582
EPOCH 29:
LOSS train 0.06710689386681452 valid 0.13043824627490366
EPOCH 30:
LOSS train 0.06159922121764745 valid 0.1308718815300874
EPOCH 31:
LOSS train 0.0626493090589439 valid 0.13349512409985112
EPOCH 32:
LOSS train 0.06317026793372935 valid 0.12088610998047042
EPOCH 33:
LOSS train 0.059206542409896624 valid 0.13173201702183934
EPOCH 34:
LOSS train 0.06160525252701808 valid 0.13098411184984213
EPOCH 35:
LOSS train 0.054729109976928765 valid 0.1361473535450957
EPOCH 36:
LOSS train 0.05555839695110448 valid 0.126241507851203
EPOCH 37:
LOSS tr

In [27]:
# Persist the per-epoch loss log to Records/swin1.csv (newline='' as the csv module requires).
with open(os.path.join(record_path, 'swin1.csv'), mode='w', newline='') as record_file:
    csv.writer(record_file).writerows(record)

## Evaluation

In [37]:
# Load the epoch-24 checkpoint for evaluation (presumably chosen for its lowest
# validation loss in the resumed run's log). map_location lets a GPU-saved
# checkpoint load when `device` is the CPU.
path = os.path.join(model_path, 'swin1', 'model_24')
model.load_state_dict(torch.load(path, map_location=device, weights_only=True))

<All keys matched successfully>

In [38]:
# Test-set report for the loaded checkpoint.
with torch.no_grad():
    evaluate(model=model, dataset=weeds.test, batch_size=30, classes=weeds.classes)

[[ 89   4   0   0   0   2   0   3   2]
 [  0 100   0   0   0   0   0   2   1]
 [  0   0 118   0   1   0   0   0   1]
 [  1   0   1  89   0   0   0   0   2]
 [  0   0   1   0 101   0   0   0   0]
 [  1   0   0   0   0 107   0   0   0]
 [  0   0   0   0   0   0 108   0   2]
 [  2   0   0   0   0   0   0 113   4]
 [  3   6   2   6   5   8   3   3 860]]
Precision for class Chinee apple: 0.9271
Recall for class Chinee apple: 0.8900
F1-score for class Chinee apple: 0.9082
Precision for class Lantana: 0.9091
Recall for class Lantana: 0.9709
F1-score for class Lantana: 0.9390
Precision for class Parkinsonia: 0.9672
Recall for class Parkinsonia: 0.9833
F1-score for class Parkinsonia: 0.9752
Precision for class Parthenium: 0.9368
Recall for class Parthenium: 0.9570
F1-score for class Parthenium: 0.9468
Precision for class Prickly acacia: 0.9439
Recall for class Prickly acacia: 0.9902
F1-score for class Prickly acacia: 0.9665
Precision for class Rubber vine: 0.9145
Recall for class Rubber vine: 0

In [None]:
# Validation-set report for the same checkpoint.
with torch.no_grad():
    evaluate(model=model, dataset=weeds.valid, batch_size=30, classes=weeds.classes)