In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import random
from datetime import datetime
import time
import math

import torch
import torchvision
import torch.nn as nn
import torchinfo as info
from torch.utils.data import Dataset
from torchvision.transforms import functional
from torchvision.transforms import v2

In [2]:
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
set_seed(500)

In [3]:
wd = os.getcwd()
device = torch.device('cuda')

## Import Swin_V2 Model

In [4]:
from torchvision.models import swin_v2_b, Swin_V2_B_Weights

In [5]:
swin_model = swin_v2_b(weights=Swin_V2_B_Weights.DEFAULT).to(device)
preprocess = Swin_V2_B_Weights.DEFAULT.transforms()

In [6]:
preprocess

ImageClassification(
    crop_size=[256]
    resize_size=[272]
    mean=[0.485, 0.456, 0.406]
    std=[0.229, 0.224, 0.225]
    interpolation=InterpolationMode.BICUBIC
)

## Load and Process Data

In [7]:
# Class representing a set
class WeedsDataset(Dataset):
    def __init__(self, labels, img_dir, transform = None):
        self.labels = pd.read_csv(labels)
        self.img_dir = img_dir
        self.transform = transform
        self.classes = self.labels[['Label','Species']].drop_duplicates().sort_values(by = 'Label').reset_index(drop = True)['Species']
        self.data = []
        self.train = []
        self.validation = []
        self.test = []
        
    # Load entire dataset into memory
    def __loaddata__(self):
        for row in self.labels.itertuples():
            filename = row.Filename
            label = torch.tensor(row.Label)
            img_path = os.path.join(self.img_dir, filename)
            image = self.transform(torchvision.io.read_image(img_path))
            self.data.append([image,label])
        del self.labels
        print('Data has been loaded')

    # Split into train/validation/test
    def __split__(self):
        random.shuffle(self.data)
        n = len(self.data)
        n_train = round(n * 0.8)
        n_valid = round(n * 0.1)
        
        self.train = self.data[:n_train]
        self.valid = self.data[n_train:n_train + n_valid]
        self.test = self.data[n_train + n_valid:]
        del self.data
        
        print('Data has been split')
        print('Size of training set is {}'.format(len(self.train)))
        print('Size of validation set is {}'.format(len(self.valid)))
        print('Size of test set is {}'.format(len(self.test)))

    # Augment with rotations and gaussian noise
    def __augment__(self, frac):
        n = round(len(self.train) * frac / 5)
        rotate_90 = random.sample(self.train, n)
        for row in rotate_90:
            image = row[0]
            row[0] = functional.rotate(image, 90)
        rotate_180 = random.sample(self.train, n)
        for row in rotate_180:
            image = row[0]
            row[0] = functional.rotate(image, 180)
        rotate_270 = random.sample(self.train,n)
        for row in rotate_270:
            image = row[0]
            row[0] = functional.rotate(image, 270)
        self.train += rotate_90 + rotate_180 + rotate_270
        gaussian = v2.GaussianNoise(0,0.08,False)
        noise = random.sample(self.train, 2*n)
        for row in noise:
            image = row[0]
            row[0] = gaussian(image)
        self.train += noise
        
        print('Data has been augmented')
        print('Size of training set is {}'.format(len(self.train)))
        
    # Get item from sets, format is (image,label)
    def __getitem__(self, idx, split):
        if split == 'train':
            item = self.train[idx]
        elif split == 'valid':
            item = self.valid[idx]
        elif split == 'test':
            item = self.test[idx]
        return item[0], item[1]

In [8]:
# Run functions to load and process data
weeds = WeedsDataset(rf'{wd}\labels.csv', rf'{wd}\images', preprocess)
weeds.__loaddata__()
weeds.__split__()
weeds.__augment__(0.5)

Data has been loaded
Data has been split
Size of training set is 14007
Size of validation set is 1751
Size of test set is 1751
Data has been augmented
Size of training set is 21012


In [9]:
# Functions to convert the label back into the class
def convert_label(label, dataset):
    return dataset.classes[label.item()]

def convert_labels(labels,dataset):
    lst = []
    for label in labels:
        data_class = convert_label(label,dataset)
        lst.append(data_class)
    return lst

## Load data

In [10]:
# Sends image and label tensors to gpu
def get_batch(dataset, current_index, batch_size):
    images,labels = zip(*dataset[current_index:current_index+batch_size])
    return torch.stack(images, dim = 0).to(device), torch.stack(labels).to(device)

In [21]:
# Check that everything is working properly
def matplotlib_imshow(img, one_channel=False):
    if one_channel:
        img = img.mean(dim=0)
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    if one_channel:
        plt.imshow(npimg, cmap="Greys")
    else:
        plt.imshow(np.transpose(npimg, (1, 2, 0)))

# Create a grid from the images and show them
#images,labels = get_batch(weeds.train, 1000, 32)
#img_grid = torchvision.utils.make_grid(images.cpu())
#matplotlib_imshow(img_grid, one_channel=True)
#print(convert_labels(labels.cpu(),weeds))

## Replacing classification head

In [11]:
for param in swin_model.parameters():
    param.requires_grad = False
layers = []
layers.append(nn.Linear(in_features = 1024, out_features = 256, bias = True))
layers.append(nn.ReLU())
layers.append(nn.Linear(in_features = 256, out_features = 9, bias = True))
swin_model.head = nn.Sequential(*layers).to(device)

In [12]:
swin_model.head

Sequential(
  (0): Linear(in_features=1024, out_features=256, bias=True)
  (1): ReLU()
  (2): Linear(in_features=256, out_features=9, bias=True)
)

In [None]:
#info.summary(swin_model)

## Training

In [13]:
loss_fn = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(swin_model.parameters(), lr=0.0012, momentum = 0.9)

In [14]:
def train_one_epoch(batch_size):
    batches = math.floor(len(weeds.train)/batch_size)
    current_index = 0
    running_loss = 0
    random.shuffle(weeds.train)

    for i in range(batches):
        inputs,labels = get_batch(weeds.train, current_index, batch_size)
        current_index += batch_size
        optimizer.zero_grad(set_to_none=True)
        outputs = swin_model(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.cpu().item()

    remainder = len(weeds.train)%batch_size
    if remainder != 0:
        inputs,labels = get_batch(weeds.train, current_index, batch_size)
        optimizer.zero_grad()
        outputs = swin_model(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.cpu().item()
        batches += 1
    
    return running_loss/batches

In [15]:
def validate(batch_size):
    batches = math.floor(len(weeds.valid)/batch_size)
    current_index = 0
    running_loss = 0
    
    for i in range(batches):
        inputs,labels = get_batch(weeds.valid, current_index, batch_size)
        current_index += batch_size
        output = swin_model(inputs)
        running_loss += loss_fn(output, labels).cpu().item()

    remainder = len(weeds.valid)%batch_size
    if remainder != 0:
        inputs,labels = get_batch(weeds.valid, current_index, batch_size)
        output = swin_model(inputs)
        running_loss += loss_fn(output, labels).cpu().item()
        batches += 1
    
    return running_loss/batches

In [16]:
def train(epochs, batch_size, epoch_counter, folder_name, save):
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    current_vloss = 0
    start = time.perf_counter()
    rounds = 1
    
    if not os.path.exists('{}\{}'.format(wd,folder_name)):
        os.makedirs('{}\{}'.format(wd,folder_name))
        
    for epoch in range(epochs):
        print('EPOCH {}:'.format(epoch_counter + 1))
        swin_model.train()
        avg_loss = train_one_epoch(batch_size)
        
        swin_model.eval()
        with torch.no_grad():
            current_vloss = validate(batch_size)

        # Only save last {save} rounds
        if rounds + save >= epochs:
            model_path = '{}\model_{}_{}'.format(folder_name, timestamp, epoch_counter + 1)
            torch.save(swin_model.state_dict(), model_path)
        rounds += 1

        print('LOSS train {} valid {}'.format(avg_loss, current_vloss))
        epoch_counter += 1

    end = time.perf_counter()
    print('Average time taken per epoch is {}s'.format((end - start)/epochs))
    return

In [29]:
train(50,50,0,'testmodel',10)

EPOCH 1:
LOSS train 1.3390585229402483 valid 0.9432550147175789
EPOCH 2:
LOSS train 0.8812311181263233 valid 0.7294931982954344
EPOCH 3:
LOSS train 0.723507198471921 valid 0.6235770475533273
EPOCH 4:
LOSS train 0.6283326989964465 valid 0.5532150583134757
EPOCH 5:
LOSS train 0.5730355395415616 valid 0.5062037453883224
EPOCH 6:
LOSS train 0.5310607932667268 valid 0.45651991789539653
EPOCH 7:
LOSS train 0.49577513542424473 valid 0.45451483585768276
EPOCH 8:
LOSS train 0.46962585737331464 valid 0.40802961587905884
EPOCH 9:
LOSS train 0.445639189483151 valid 0.4093300679491626
EPOCH 10:
LOSS train 0.42832543242572324 valid 0.38590303984367186
EPOCH 11:
LOSS train 0.41696283238942156 valid 0.36478071970244247
EPOCH 12:
LOSS train 0.403304303733017 valid 0.35426897183060646
EPOCH 13:
LOSS train 0.3862745182460391 valid 0.3463750943127606
EPOCH 14:
LOSS train 0.3713138032341796 valid 0.3371264518549045
EPOCH 15:
LOSS train 0.3663419714095757 valid 0.3275220168547498
EPOCH 16:
LOSS train 0.3563

In [30]:
# Run another 25 epochs
train(25,50,50,'testmodel',10)

EPOCH 51:
LOSS train 0.21138129660288965 valid 0.23275276894370714
EPOCH 52:
LOSS train 0.20749398246584483 valid 0.23221596748205936
EPOCH 53:
LOSS train 0.20058914164253766 valid 0.2224171869456768
EPOCH 54:
LOSS train 0.203108948662063 valid 0.23644539509485993
EPOCH 55:
LOSS train 0.20035356676068838 valid 0.2310625351448026
EPOCH 56:
LOSS train 0.19889075678027723 valid 0.22668670932115573
EPOCH 57:
LOSS train 0.19711799507257208 valid 0.23780708434060216
EPOCH 58:
LOSS train 0.19501334040330595 valid 0.22438563189158836
EPOCH 59:
LOSS train 0.19186643657944832 valid 0.22459907255445918
EPOCH 60:
LOSS train 0.19255517432886463 valid 0.2230653534902053
EPOCH 61:
LOSS train 0.19230194224792393 valid 0.22709874242233732
EPOCH 62:
LOSS train 0.18994908401136443 valid 0.2247819554629839
EPOCH 63:
LOSS train 0.18788126959618084 valid 0.21815137513395813
EPOCH 64:
LOSS train 0.1830734115828386 valid 0.22074600856285542
EPOCH 65:
LOSS train 0.1783893779982014 valid 0.21850013782063293
EPO

In [31]:
# Run another 10 epochs
train(10,50,75,'testmodel',10)

EPOCH 76:
LOSS train 0.16367355166910097 valid 0.2128411340397886
EPOCH 77:
LOSS train 0.16301728439939844 valid 0.21041587779634735
EPOCH 78:
LOSS train 0.1630751442234972 valid 0.22084775999084943
EPOCH 79:
LOSS train 0.1587625771778638 valid 0.21552396984770894
EPOCH 80:
LOSS train 0.1595259168121141 valid 0.20760167340308222
EPOCH 81:
LOSS train 0.15745674547542443 valid 0.20671661055853796
EPOCH 82:
LOSS train 0.15689195425309366 valid 0.20413677550904039
EPOCH 83:
LOSS train 0.15562974101467825 valid 0.20506172221050495
EPOCH 84:
LOSS train 0.15468492368214062 valid 0.220046156590494
EPOCH 85:
LOSS train 0.15559231555376937 valid 0.2315680438057623
Average time taken per epoch is 172.6834253799985s


In [32]:
# Epoch 82 has lowest validation loss so continue from there with lower lr and momentum
swin_model.load_state_dict(torch.load('testmodel\model_20250225_112924_82', weights_only = True))
optimizer = torch.optim.SGD(swin_model.parameters(), lr=0.0008, momentum = 0.4)
train(8,50,82,'testmodel',8)

EPOCH 83:
LOSS train 0.14972701666969584 valid 0.20491910196789023
EPOCH 84:
LOSS train 0.14959352199350995 valid 0.2066789831992032
EPOCH 85:
LOSS train 0.14195913106486235 valid 0.20553429185464564
EPOCH 86:
LOSS train 0.14736420846590662 valid 0.20875638992422157
EPOCH 87:
LOSS train 0.14610925996388505 valid 0.20624750112700793
EPOCH 88:
LOSS train 0.14516572598943092 valid 0.2068008242123243
EPOCH 89:
LOSS train 0.1464302123805194 valid 0.20681307586427364
EPOCH 90:
LOSS train 0.14122015711212102 valid 0.20444773777853698
Average time taken per epoch is 171.6681134999999s


In [33]:
# Another 5 epochs
train(5,50,90,'testmodel',5)

EPOCH 91:
LOSS train 0.14382885376128618 valid 0.2046591496715943
EPOCH 92:
LOSS train 0.14452524261102256 valid 0.2064298023469746
EPOCH 93:
LOSS train 0.14626077479835503 valid 0.20892401467750055
EPOCH 94:
LOSS train 0.14717311182915457 valid 0.20388003129563811
EPOCH 95:
LOSS train 0.14596272116159742 valid 0.20464123661319414
Average time taken per epoch is 165.95095590000273s


In [38]:
# Another 15 epochs
train(15,50,95,'testmodel',15)

EPOCH 96:
LOSS train 0.1446822808520103 valid 0.20132448068276668
EPOCH 97:
LOSS train 0.14158068325220122 valid 0.2035724569319023
EPOCH 98:
LOSS train 0.14351433894714238 valid 0.20157652236391893
EPOCH 99:
LOSS train 0.14470897748128395 valid 0.20637125166184786
EPOCH 100:
LOSS train 0.1420655709972291 valid 0.20792183633117625
EPOCH 101:
LOSS train 0.14546317078456317 valid 0.2028987623570073
EPOCH 102:
LOSS train 0.14408251759983298 valid 0.20585942750848416
EPOCH 103:
LOSS train 0.14432914624092413 valid 0.20309232489671558
EPOCH 104:
LOSS train 0.14650239195948966 valid 0.20421617870063832
EPOCH 105:
LOSS train 0.14196060319163453 valid 0.20999597037573242
EPOCH 106:
LOSS train 0.1420444050082614 valid 0.2078818557654611
EPOCH 107:
LOSS train 0.14309469083956755 valid 0.20478275088438144
EPOCH 108:
LOSS train 0.13905824877480996 valid 0.20949100140326968
EPOCH 109:
LOSS train 0.1403702573049946 valid 0.20431711139260894
EPOCH 110:
LOSS train 0.14525314658490893 valid 0.204240976

In [40]:
# Lower lr further from epoch 96
swin_model.load_state_dict(torch.load('testmodel\model_20250225_123744_96', weights_only = True))
optimizer = torch.optim.SGD(swin_model.parameters(), lr=0.0004, momentum = 0.2)
train(14,50,96,'testmodel',14)

EPOCH 97:
LOSS train 0.14231866737938267 valid 0.20359141475960818
EPOCH 98:
LOSS train 0.14646978036560243 valid 0.20504553794550398
EPOCH 99:
LOSS train 0.14750776579714162 valid 0.2046950082035942
EPOCH 100:
LOSS train 0.1424761791702475 valid 0.20380187580465442
EPOCH 101:
LOSS train 0.14044551499577287 valid 0.20409172483616406
EPOCH 102:
LOSS train 0.13970945496775758 valid 0.20454917751097432
EPOCH 103:
LOSS train 0.14664217901378515 valid 0.20606204584085694
EPOCH 104:
LOSS train 0.14437565368748406 valid 0.20615295729496413
EPOCH 105:
LOSS train 0.14174389667508155 valid 0.20559932613590112
EPOCH 106:
LOSS train 0.14457347820515162 valid 0.2044525071590518
EPOCH 107:
LOSS train 0.14289272550024498 valid 0.20541364148569605
EPOCH 108:
LOSS train 0.14332885559727915 valid 0.20528541850702217
EPOCH 109:
LOSS train 0.144936253402975 valid 0.20464071342980283
EPOCH 110:
LOSS train 0.1417920665501345 valid 0.20263445753759393
Average time taken per epoch is 159.89958819285883s


## Evaluation

In [None]:
# Lowest validation loss is 0.20132448068276668 at epoch 96

In [24]:
swin_model.load_state_dict(torch.load('testmodel\model_20250225_123744_96', weights_only = True))

<All keys matched successfully>

In [22]:
def evaluate(batch_size):
    start = time.perf_counter()
    total = 0
    correct = 0
    batches = math.floor(len(weeds.test)/batch_size)
    current_index = 0
    
    for i in range(batches):
        inputs,labels = get_batch(weeds.test, current_index, batch_size)
        current_index += batch_size
        prediction = torch.argmax(swin_model(inputs), dim = 1)
        correct += sum(prediction == labels).cpu().item()
        total += batch_size

    remainder = len(weeds.test)%batch_size
    if remainder != 0:
        inputs,labels = get_batch(weeds.test, current_index, remainder)
        prediction = torch.argmax(swin_model(inputs), dim = 1)
        correct += sum(prediction == labels).cpu().item()
        total += remainder

    accuracy = 100*correct/total
    print ('Accuracy is {}%'.format(accuracy))
    end = time.perf_counter()
    print('Time taken is {}'.format(end-start))
    
    return 

In [25]:
with torch.no_grad():
    swin_model.eval()
    evaluate(20)

Accuracy is 93.60365505425472%
Time taken is 12.76658049999969


In [None]:
# Accuracy is 93.6%

In [43]:
torch.save(swin_model, 'trained_model\weedsv2.pt')