In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import random
from datetime import datetime
import time
import math

import torch
import torchvision
import torch.nn as nn
import torchinfo as info
from torch.utils.data import Dataset
from torchvision.transforms import functional
from torchvision.transforms import v2

In [2]:
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
set_seed(500)

In [3]:
wd = os.getcwd()
device = torch.device('cuda')

## Import Swin_V2 Model

In [4]:
from torchvision.models import swin_v2_t, Swin_V2_T_Weights

In [5]:
swin_model = swin_v2_t(weights=Swin_V2_T_Weights.DEFAULT).to(device)
preprocess = Swin_V2_T_Weights.DEFAULT.transforms()

## Load and Process Data

In [6]:
# Class representing a set
class WeedsDataset(Dataset):
    def __init__(self, labels, img_dir, transform = None):
        self.labels = pd.read_csv(labels)
        self.img_dir = img_dir
        self.transform = transform
        self.classes = self.labels[['Label','Species']].drop_duplicates().sort_values(by = 'Label').reset_index(drop = True)['Species']
        self.data = []
        self.train = []
        self.validation = []
        self.test = []
        
    # Load entire dataset into memory
    def __loaddata__(self):
        for row in self.labels.itertuples():
            filename = row.Filename
            label = torch.tensor(row.Label)
            img_path = os.path.join(self.img_dir, filename)
            image = self.transform(torchvision.io.read_image(img_path))
            self.data.append([image,label])
        del self.labels
        print('Data has been loaded')

    # Split into train/validation/test
    def __split__(self):
        random.shuffle(self.data)
        n = len(self.data)
        n_train = round(n * 0.8)
        n_valid = round(n * 0.1)
        
        self.train = self.data[:n_train]
        self.valid = self.data[n_train:n_train + n_valid]
        self.test = self.data[n_train + n_valid:]
        del self.data
        
        print('Data has been split')
        print('Size of training set is {}'.format(len(self.train)))
        print('Size of validation set is {}'.format(len(self.valid)))
        print('Size of test set is {}'.format(len(self.test)))

    # Augment with rotations and gaussian noise
    def __augment__(self, frac):
        n = round(len(self.train) * frac / 5)
        rotate_90 = random.sample(self.train, n)
        for row in rotate_90:
            image = row[0]
            row[0] = functional.rotate(image, 90)
        rotate_180 = random.sample(self.train, n)
        for row in rotate_180:
            image = row[0]
            row[0] = functional.rotate(image, 180)
        rotate_270 = random.sample(self.train,n)
        for row in rotate_270:
            image = row[0]
            row[0] = functional.rotate(image, 270)
        self.train += rotate_90 + rotate_180 + rotate_270
        gaussian = v2.GaussianNoise(0,0.08,False)
        noise = random.sample(self.train, 2*n)
        for row in noise:
            image = row[0]
            row[0] = gaussian(image)
        self.train += noise
        
        print('Data has been augmented')
        print('Size of training set is {}'.format(len(self.train)))
        
    # Get item from sets, format is (image,label)
    def __getitem__(self, idx, split):
        if split == 'train':
            item = self.train[idx]
        elif split == 'valid':
            item = self.valid[idx]
        elif split == 'test':
            item = self.test[idx]
        return item[0], item[1]

In [7]:
# Run functions to load and process data
weeds = WeedsDataset(rf'{wd}\labels.csv', rf'{wd}\images', preprocess)
weeds.__loaddata__()
weeds.__split__()
weeds.__augment__(0.5)

Data has been loaded
Data has been split
Size of training set is 14007
Size of validation set is 1751
Size of test set is 1751
Data has been augmented
Size of training set is 21012


In [8]:
# Functions to convert the label back into the class
def convert_label(label, dataset):
    return dataset.classes[label.item()]

def convert_labels(labels,dataset):
    lst = []
    for label in labels:
        data_class = convert_label(label,dataset)
        lst.append(data_class)
    return lst

## Load data

In [20]:
# Sends image and label tensors to gpu
def get_batch(dataset, current_index, batch_size):
    images,labels = zip(*dataset[current_index:current_index+batch_size])
    return torch.stack(images, dim = 0).to(device), torch.stack(labels).to(device)

In [10]:
# Check that everything is working properly
def matplotlib_imshow(img, one_channel=False):
    if one_channel:
        img = img.mean(dim=0)
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    if one_channel:
        plt.imshow(npimg, cmap="Greys")
    else:
        plt.imshow(np.transpose(npimg, (1, 2, 0)))

# Create a grid from the images and show them
#images,labels = get_batch(weeds.train, 1000, 32)
#img_grid = torchvision.utils.make_grid(images.cpu())
#matplotlib_imshow(img_grid, one_channel=True)
#print(convert_labels(labels.cpu(),weeds))

## Replacing classification head

In [11]:
for param in swin_model.parameters():
    param.requires_grad = False
layers = []
layers.append(nn.Linear(in_features = 768, out_features = 256, bias = True))
layers.append(nn.ReLU())
layers.append(nn.Linear(in_features = 256, out_features = 9, bias = True))
swin_model.head = nn.Sequential(*layers).to(device)

In [12]:
#info.summary(swin_model)

## Training

In [13]:
loss_fn = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(swin_model.parameters(), lr=0.0012, momentum = 0.9)

In [14]:
def train_one_epoch(batch_size):
    batches = math.floor(len(weeds.train)/batch_size)
    current_index = 0
    running_loss = 0
    random.shuffle(weeds.train)

    for i in range(batches):
        inputs,labels = get_batch(weeds.train, current_index, batch_size)
        current_index += batch_size
        optimizer.zero_grad(set_to_none=True)
        outputs = swin_model(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.cpu().item()

    remainder = len(weeds.train)%batch_size
    if remainder != 0:
        inputs,labels = get_batch(weeds.train, current_index, batch_size)
        optimizer.zero_grad()
        outputs = swin_model(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.cpu().item()
        batches += 1
    
    return running_loss/batches

In [15]:
def validate(batch_size):
    batches = math.floor(len(weeds.valid)/batch_size)
    current_index = 0
    running_loss = 0
    
    for i in range(batches):
        inputs,labels = get_batch(weeds.valid, current_index, batch_size)
        current_index += batch_size
        output = swin_model(inputs)
        running_loss += loss_fn(output, labels).cpu().item()

    remainder = len(weeds.valid)%batch_size
    if remainder != 0:
        inputs,labels = get_batch(weeds.valid, current_index, batch_size)
        output = swin_model(inputs)
        running_loss += loss_fn(output, labels).cpu().item()
        batches += 1
    
    return running_loss/batches

In [16]:
def train(epochs, batch_size, epoch_counter, folder_name, save):
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    current_vloss = 0
    start = time.perf_counter()
    rounds = 1
    
    if not os.path.exists('{}\{}'.format(wd,folder_name)):
        os.makedirs('{}\{}'.format(wd,folder_name))
        
    for epoch in range(epochs):
        print('EPOCH {}:'.format(epoch_counter + 1))
        swin_model.train()
        avg_loss = train_one_epoch(batch_size)
        
        swin_model.eval()
        with torch.no_grad():
            current_vloss = validate(batch_size)

        # Only save last {save} rounds
        if rounds + save >= epochs:
            model_path = '{}\model_{}_{}'.format(folder_name, timestamp, epoch_counter + 1)
            torch.save(swin_model.state_dict(), model_path)
        rounds += 1

        print('LOSS train {} valid {}'.format(avg_loss, current_vloss))
        epoch_counter += 1

    end = time.perf_counter()
    print('Average time taken per epoch is {}s'.format((end - start)/epochs))
    return

In [19]:
train(75,50,0,'finalmodel',10)

EPOCH 1:
LOSS train 1.3283949804419293 valid 0.9263197030458186
EPOCH 2:
LOSS train 0.8506016942922406 valid 0.6950112423963017
EPOCH 3:
LOSS train 0.6869847155098677 valid 0.5846914268202252
EPOCH 4:
LOSS train 0.5984274149648934 valid 0.5169751710361905
EPOCH 5:
LOSS train 0.5416168422568722 valid 0.4809159982121653
EPOCH 6:
LOSS train 0.5024519414771481 valid 0.43552979040477013
EPOCH 7:
LOSS train 0.4674713804503801 valid 0.4108854437039958
EPOCH 8:
LOSS train 0.4395132194009926 valid 0.3928733350088199
EPOCH 9:
LOSS train 0.4189249133412459 valid 0.378219054494467
EPOCH 10:
LOSS train 0.4046234784740733 valid 0.3646091506299045
EPOCH 11:
LOSS train 0.383369430687252 valid 0.3567312653693888
EPOCH 12:
LOSS train 0.3714211008327874 valid 0.34069733663151663
EPOCH 13:
LOSS train 0.36054087286890263 valid 0.33041859459545875
EPOCH 14:
LOSS train 0.3486629986600185 valid 0.3200023152037627
EPOCH 15:
LOSS train 0.336873929885674 valid 0.3180439403384096
EPOCH 16:
LOSS train 0.3265971246

In [20]:
# Run another 10 epochs to see results
train(10,50,75,'finalmodel',10)

EPOCH 76:
LOSS train 0.13428126191919462 valid 0.22334949125070125
EPOCH 77:
LOSS train 0.1325804835136741 valid 0.21746236010868517
EPOCH 78:
LOSS train 0.1293237548179907 valid 0.21840984343240658
EPOCH 79:
LOSS train 0.12864787207646494 valid 0.21463559099680019
EPOCH 80:
LOSS train 0.13038201406686578 valid 0.2208203960261825
EPOCH 81:
LOSS train 0.12903293339433836 valid 0.22080236091278493
EPOCH 82:
LOSS train 0.12964104047385785 valid 0.22252882631599075
EPOCH 83:
LOSS train 0.12572059669308452 valid 0.22042685384965605
EPOCH 84:
LOSS train 0.1213154908747472 valid 0.23572370798016587
EPOCH 85:
LOSS train 0.12662270518723265 valid 0.2172468727351063
Average time taken per epoch is 70.88935907000004s


In [21]:
# Epoch 75 has lowest validation loss so continue from there with lower lr and momentum
swin_model.load_state_dict(torch.load('finalmodel\model_20250222_192234_75', weights_only = True))
optimizer = torch.optim.SGD(swin_model.parameters(), lr=0.0008, momentum = 0.4)
train(10,50,75,'finalmodel',10)

EPOCH 76:
LOSS train 0.12689872474947198 valid 0.21235818152005473
EPOCH 77:
LOSS train 0.12964374058815759 valid 0.21381686332946023
EPOCH 78:
LOSS train 0.1300007206854585 valid 0.21204141070807558
EPOCH 79:
LOSS train 0.12715526632610663 valid 0.21365343525798786
EPOCH 80:
LOSS train 0.126066409825677 valid 0.2118403399751211
EPOCH 81:
LOSS train 0.12428464481691141 valid 0.214892825189357
EPOCH 82:
LOSS train 0.12518375067398696 valid 0.21103591552107698
EPOCH 83:
LOSS train 0.12449970492770157 valid 0.21073284458058575
EPOCH 84:
LOSS train 0.1269807388957351 valid 0.2100775209772918
EPOCH 85:
LOSS train 0.12738710942273326 valid 0.21238866809289902
Average time taken per epoch is 73.46449511s


In [22]:
# Run another 10 epochs with same settings
train(10,50,85,'finalmodel',10)

EPOCH 86:
LOSS train 0.1257564882458035 valid 0.21137572122582546
EPOCH 87:
LOSS train 0.12660341860279617 valid 0.2110931584870236
EPOCH 88:
LOSS train 0.1240571275344583 valid 0.21247096922403821
EPOCH 89:
LOSS train 0.1237922689684992 valid 0.21012785548292515
EPOCH 90:
LOSS train 0.12635042352907164 valid 0.2110995243355218
EPOCH 91:
LOSS train 0.12650962952179326 valid 0.2069510916610145
EPOCH 92:
LOSS train 0.12564620612335176 valid 0.20956924473608118
EPOCH 93:
LOSS train 0.12527784655609495 valid 0.21015708233850697
EPOCH 94:
LOSS train 0.12266297681332343 valid 0.21195605765872946
EPOCH 95:
LOSS train 0.1237290831495103 valid 0.20908217989684394
Average time taken per epoch is 72.76081088000001s


In [23]:
# Continue from epoch 91
swin_model.load_state_dict(torch.load('finalmodel\model_20250222_211532_91', weights_only = True))
optimizer = torch.optim.SGD(swin_model.parameters(), lr=0.0004, momentum = 0.1)
train(9,50,91,'finalmodel',9)

EPOCH 92:
LOSS train 0.12177798058616436 valid 0.21005653895230758
EPOCH 93:
LOSS train 0.12225227492845965 valid 0.21112188121252176
EPOCH 94:
LOSS train 0.12328396879596694 valid 0.20972251613986576
EPOCH 95:
LOSS train 0.1237388349662484 valid 0.21043359405464596
EPOCH 96:
LOSS train 0.12445266862486151 valid 0.21035036189843798
EPOCH 97:
LOSS train 0.1244147598495676 valid 0.2111690689990711
EPOCH 98:
LOSS train 0.12340053845188113 valid 0.2107173830865779
EPOCH 99:
LOSS train 0.12230413363286406 valid 0.211038783710036
EPOCH 100:
LOSS train 0.11898544694865543 valid 0.21020858390774164
Average time taken per epoch is 70.43590075555565s


In [27]:
#Try again with lower lr and momentum
swin_model.load_state_dict(torch.load('finalmodel\model_20250222_211532_91', weights_only = True))
optimizer = torch.optim.SGD(swin_model.parameters(), lr=0.0003, momentum = 0)
train(9,50,91,'finalmodel',9)

EPOCH 92:
LOSS train 0.12165545089935746 valid 0.20989099085434443
EPOCH 93:
LOSS train 0.12424832803248084 valid 0.21047276625823644
EPOCH 94:
LOSS train 0.1239845143794801 valid 0.20992867788299918
EPOCH 95:
LOSS train 0.12296401545617756 valid 0.2098694044010093
EPOCH 96:
LOSS train 0.12320876541832847 valid 0.21060868245290798
EPOCH 97:
LOSS train 0.12088567289278229 valid 0.2098168595564655
EPOCH 98:
LOSS train 0.1220386895946185 valid 0.21136812385844272
EPOCH 99:
LOSS train 0.12061398425482373 valid 0.21173184525428546
EPOCH 100:
LOSS train 0.12273840095644326 valid 0.21007908748773238
Average time taken per epoch is 72.07202336666685s


## Evaluation

In [None]:
# Lowest validation loss is 0.2069510916610145 at epoch 91

In [17]:
swin_model.load_state_dict(torch.load('finalmodel\model_20250222_211532_91', weights_only = True))

<All keys matched successfully>

In [18]:
def evaluate(batch_size):
    start = time.perf_counter()
    total = 0
    correct = 0
    batches = math.floor(len(weeds.test)/batch_size)
    current_index = 0
    
    for i in range(batches):
        inputs,labels = get_batch(weeds.test, current_index, batch_size)
        current_index += batch_size
        prediction = torch.argmax(swin_model(inputs), dim = 1)
        correct += sum(prediction == labels).cpu().item()
        total += batch_size

    remainder = len(weeds.test)%batch_size
    if remainder != 0:
        inputs,labels = get_batch(weeds.test, current_index, remainder)
        prediction = torch.argmax(swin_model(inputs), dim = 1)
        correct += sum(prediction == labels).cpu().item()
        total += remainder

    accuracy = 100*correct/total
    print ('Accuracy is {}%'.format(accuracy))
    end = time.perf_counter()
    print('Time taken is {}'.format(end-start))
    
    return 

In [21]:
with torch.no_grad():
    swin_model.eval()
    evaluate(20)

Accuracy is 92.57567104511708%
Time taken is 5.432378800003789


In [None]:
# Accuracy is 92.58%

In [36]:
torch.save(swin_model, 'trained_model\weedsv1.pt')