### Notebook for training the Litter Detection Model (PLD) on custom data

Import libraries

In [None]:
import os
import yaml
import requests
import itertools
import json
import torch
import torch.nn as nn
import torchvision.transforms.v2 as v2
import numpy as np
import matplotlib.pyplot as plt

from torch.utils.data import DataLoader, Dataset
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
from torchvision import models
from sklearn.metrics import precision_recall_fscore_support, accuracy_score
from sklearn.model_selection import train_test_split
from PIL import Image

Define methods for training

In [None]:
def download_model(destination_dir='./models',
                   base_url='https://share.services.ai4os.eu/s/dHNm3RdqYWaoWKj/download/',
                   model_filename='densenet121_PLQ.pth',
                   timeout=60,
                   verify=False):
    """Download the pretrained model weights into ``destination_dir``.

    The file is streamed to a temporary ``.part`` file and only moved to its
    final name once the transfer completed, so an interrupted download never
    leaves a truncated checkpoint at the final path.

    Args:
        destination_dir: Directory the file is stored in (created if missing).
        base_url: URL prefix; ``model_filename`` is appended to it.
        model_filename: Name of the remote file and of the local copy.
        timeout: Seconds to wait for the server before giving up.
        verify: Passed to ``requests.get``; whether TLS certificates are checked.

    Returns:
        The path of the downloaded file, or ``None`` if the request failed.
    """
    os.makedirs(destination_dir, exist_ok=True)
    full_url = f"{base_url}{model_filename}"
    destination_path = os.path.join(destination_dir, model_filename)
    partial_path = destination_path + '.part'

    try:
        # NOTE(review): verify=False (the historical default, kept for
        # compatibility) disables TLS certificate checks, so the downloaded
        # checkpoint is not authenticated. Pass verify=True if the server
        # certificate is valid.
        with requests.get(full_url, verify=verify, stream=True, timeout=timeout) as response:
            response.raise_for_status()  # Raise an error on bad status
            with open(partial_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:  # skip keep-alive chunks
                        f.write(chunk)
        os.replace(partial_path, destination_path)
    except requests.exceptions.RequestException as e:
        # Best effort, as before: report the failure instead of raising.
        if os.path.exists(partial_path):
            os.remove(partial_path)
        print(f"Failed to download the model: {e}")
        return None

    print(f"Model downloaded successfully to {destination_path}")
    return destination_path

def load_model(model_path, device, num_old_classes, num_classes):
    """Build a DenseNet121, load checkpoint weights and attach a new classifier.

    The classifier is first sized to ``num_old_classes`` so the checkpoint's
    classifier weights fit, then replaced by a freshly initialised layer with
    ``num_classes`` outputs.

    Args:
        model_path: Path of the checkpoint file holding a state dict.
        device: Device the checkpoint tensors are mapped to while loading.
        num_old_classes: Number of outputs of the checkpoint's classifier.
        num_classes: Number of outputs of the new classifier.

    Returns:
        The model with loaded weights and the new classifier layer.
    """
    # `weights=None` is the current spelling of the deprecated `pretrained=False`.
    model = models.densenet121(weights=None)
    in_features = model.classifier.in_features
    model.classifier = nn.Linear(in_features, num_old_classes)

    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # a trusted source.
    state_dict = torch.load(model_path, map_location=torch.device(device))
    load_result = model.load_state_dict(state_dict, strict=False)
    # strict=False skips non-matching keys silently; report them so a wrong
    # checkpoint does not go unnoticed.
    if load_result.missing_keys or load_result.unexpected_keys:
        print(f"Warning: checkpoint did not match the model exactly "
              f"(missing keys: {len(load_result.missing_keys)}, "
              f"unexpected keys: {len(load_result.unexpected_keys)})")

    model.classifier = nn.Linear(in_features, num_classes)
    return model

def modify_model(model, num_classes):
    """Swap the model's classifier for a new linear layer with ``num_classes`` outputs.

    The input width of the existing classifier is reused. The model is changed
    in place and also returned.
    """
    input_width = model.classifier.in_features
    new_head = nn.Linear(input_width, num_classes)
    model.classifier = new_head
    return model

def load_labels(label_path):
    """Read the old and new label lists from a YAML label file.

    The file must contain a top-level ``label`` mapping with the entries
    ``label old`` and ``label new``.

    Returns:
        A tuple ``(labels_old, labels)``.
    """
    with open(label_path, 'r') as stream:
        config = yaml.safe_load(stream)
    label_section = config['label']
    previous_labels = label_section['label old']
    current_labels = label_section['label new']
    return previous_labels, current_labels

def training_epoch(model, train_data_loader, optimizer, loss_fn, device):
    """Train ``model`` for one pass over ``train_data_loader``.

    Args:
        model: Network to optimise; switched to training mode.
        train_data_loader: Iterable of ``(inputs, targets)`` batches that supports ``len()``.
        optimizer: Optimizer updating the model parameters.
        loss_fn: Callable ``loss_fn(predictions, targets)`` returning a scalar tensor.
        device: Device the batches are moved to.

    Returns:
        A tuple ``(loss, accuracy)``. ``loss`` is a detached 0-dim CPU tensor
        holding the sample-weighted mean of the batch losses over the epoch
        (not just the last batch). ``accuracy`` is the fraction of correctly
        classified samples.

    Raises:
        ValueError: If the loader yields no samples.
    """
    model.train()
    size = len(train_data_loader)
    loss_sum, correct, total = 0.0, 0, 0
    for batch, (X, y) in enumerate(train_data_loader):
        optimizer.zero_grad()
        X = X.to(device)
        y = y.to(device)
        pred = model(X)
        loss = loss_fn(pred, y)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        if batch % 100 == 0:
            print(f'loss: {batch_loss:>.7} {batch}/{size}')

        batch_size = y.size(0)
        # Weight by batch size so a smaller final batch does not skew the mean
        # (assumes loss_fn averages over the batch, as CrossEntropyLoss does by default).
        loss_sum += batch_loss * batch_size
        correct += (torch.argmax(pred, 1) == y).sum().item()
        total += batch_size

    if total == 0:
        raise ValueError("train_data_loader did not yield any samples")
    # Returned as a tensor because callers use `.item()` / `.cpu()` on it.
    return torch.tensor(loss_sum / total), correct / total
    
def val(model, val_data_loader, loss_fn, device):
    """Evaluate ``model`` on every sample of ``val_data_loader``.

    Args:
        model: Network to evaluate; switched to evaluation mode.
        val_data_loader: Iterable of ``(inputs, targets)`` batches of any batch size.
        loss_fn: Callable ``loss_fn(predictions, targets)`` returning a scalar tensor.
        device: Device the inputs are moved to for the forward pass.

    Returns:
        A tuple ``(precision, recall, f1_score, mean_loss, accuracy)``.
        Precision, recall and F1 are the weighted averages over classes, and
        ``mean_loss`` is the sample-weighted mean of the batch losses.

    Raises:
        ValueError: If the loader yields no samples.
    """
    model.eval()
    true_vals = []
    pred_vals = []
    loss_sum = 0.0
    with torch.no_grad():
        for X, y in val_data_loader:
            X = X.to(device)
            y = y.cpu()
            pred = model(X).cpu()
            # Collect the whole batch; indexing with [0] would score only the
            # first sample of each batch.
            pred_vals.extend(torch.argmax(pred, 1).tolist())
            true_vals.extend(y.tolist())
            # Weight by batch size (assumes loss_fn averages over the batch).
            loss_sum += loss_fn(pred, y).item() * y.size(0)
    if not true_vals:
        raise ValueError("val_data_loader did not yield any samples")
    precision, recall, f1_score, _ = precision_recall_fscore_support(true_vals, pred_vals, average='weighted', zero_division=0)
    accuracy = accuracy_score(true_vals, pred_vals)
    return precision, recall, f1_score, loss_sum / len(true_vals), accuracy
    
def train(num_epochs, model, train_data_loader,test_data_loader, optimizer, loss_fn, device, log_path):
    """Train ``model`` for ``num_epochs`` epochs and validate after each one.

    Every 10th epoch the model weights are saved to
    ``checkpoint_<epoch>_weights.pth`` in ``log_path``. The collected metrics
    are written to ``metrics.json`` in ``log_path`` every 10th epoch and after
    the final epoch.

    Args:
        num_epochs: Number of epochs to run.
        model: Network to train.
        train_data_loader: Loader passed to ``training_epoch``.
        test_data_loader: Loader passed to ``val`` after each epoch.
        optimizer: Optimizer updating the model parameters.
        loss_fn: Loss callable used for training and validation.
        device: Device the batches are moved to.
        log_path: Directory for metrics and checkpoints (created if missing).

    Returns:
        A tuple ``(model, train_losses, val_losses, train_acc, val_acc)`` with
        one list entry per epoch.
    """
    # Create the log directory up front so the first checkpoint cannot fail
    # after several epochs of training.
    os.makedirs(log_path, exist_ok=True)
    metrics_file = os.path.join(log_path, 'metrics.json')

    train_losses, val_losses = [], []
    train_acc, val_acc = [], []
    metric_results = {m: [] for m in ['loss', 'precision', 'recall', 'f1-score']}
    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}\n-------------------------------")
        train_loss, train_accuracy = training_epoch(model, train_data_loader, optimizer, loss_fn, device)
        precision, recall, f1_score, val_loss, val_accuracy = val(model, test_data_loader, loss_fn, device)
        train_loss_value = train_loss.item()

        val_losses.append(val_loss)
        train_losses.append(train_loss_value)
        train_acc.append(train_accuracy)
        val_acc.append(val_accuracy)

        # Plain floats keep the metrics JSON-serialisable.
        metric_results['loss'].append(train_loss_value)
        metric_results['precision'].append(float(precision))
        metric_results['recall'].append(float(recall))
        metric_results['f1-score'].append(float(f1_score))

        is_checkpoint_epoch = (epoch + 1) % 10 == 0
        if is_checkpoint_epoch or epoch + 1 == num_epochs:
            with open(metrics_file, 'w') as file:
                # Dump the dict directly; dumping an already-encoded string
                # would store one quoted JSON string instead of an object.
                json.dump(metric_results, file)
        if is_checkpoint_epoch:
            torch.save(model.state_dict(), os.path.join(log_path, f'checkpoint_{epoch + 1}_weights.pth'))
        print(f"Precision: {precision} Recall: {recall} F1-Score: {f1_score}")
    print("Done")
    return model, train_losses, val_losses, train_acc, val_acc


Methods for model testing

In [None]:
def test(model, test_data_loader, loss_fn, device):
    """Compute weighted precision, recall and F1 of ``model`` on the test data.

    Args:
        model: Network to evaluate; switched to evaluation mode.
        test_data_loader: Iterable of ``(inputs, targets)`` batches of any batch size.
        loss_fn: Unused; kept so existing calls keep working.
        device: Device the inputs are moved to for the forward pass.

    Returns:
        A tuple ``(precision, recall, f1_score)``.
    """
    model.eval()
    true_vals = []
    pred_vals = []
    with torch.no_grad():
        for X, y in test_data_loader:
            pred = model(X.to(device))
            # Collect the whole batch; indexing with [0] would score only the
            # first sample when the loader uses a batch size above 1.
            pred_vals.extend(torch.argmax(pred, 1).cpu().tolist())
            true_vals.extend(y.cpu().tolist())
    precision, recall, f1_score, _ = precision_recall_fscore_support(true_vals, pred_vals, average='weighted', zero_division=0)
    return precision, recall, f1_score

def plot_losses(train_losses, test_losses):
    """Show the training and validation loss per epoch in one figure."""
    text_size = 22
    curves = ((train_losses, 'Training'), (test_losses, 'Validation'))

    plt.figure(figsize=(8, 5))
    for values, name in curves:
        plt.plot(values, label=name, linewidth=3)
    plt.xlabel('Epochs', fontsize=text_size)
    plt.ylabel('Loss', fontsize=text_size)
    plt.xticks(fontsize=text_size)
    plt.yticks(fontsize=text_size)
    plt.title('DenseNet121 PLQ', fontsize=24)
    plt.legend(fontsize=text_size)
    plt.grid(True)
    plt.show()

def plot_acc(train_acc, test_acc):
    """Show the training and validation accuracy per epoch in one figure."""
    text_size = 22
    curves = ((train_acc, 'Training'), (test_acc, 'Validation'))

    plt.figure(figsize=(8, 5))
    for values, name in curves:
        plt.plot(values, label=name, linewidth=3)
    plt.xlabel('Epochs', fontsize=text_size)
    plt.ylabel('Accuracy', fontsize=text_size)
    plt.xticks(fontsize=text_size)
    plt.yticks(fontsize=text_size)
    plt.title('DenseNet121 PLQ', fontsize=24)
    plt.legend(fontsize=text_size)
    plt.grid(True)
    plt.show()

Dataset class and data augmentation

In [None]:
class CustomDataset(Dataset):
    """Image-folder dataset with a seeded, stratified train/val/test split.

    Expects ``dir`` to contain one sub-directory per class, named after an
    entry of ``labels``, holding ``.png`` images. The class index of a sample
    is the position of its directory name in ``labels``.

    Args:
        dir: Root directory of the data set.
        labels: Sequence of class names; the position defines the class index.
        transform: Optional callable applied to each loaded RGB PIL image.
        split: ``'train'`` or ``'val'``; any other value selects the test split.
        seed: Random state of the split, so all three splits built with the
            same arguments are disjoint.
        train_split: Fraction of all samples used for training.
        val_split: Fraction of all samples used for validation; the remainder
            after training and validation is the test split.

    Raises:
        ValueError: If no image is found below ``dir``.
        KeyError: If images are found in a directory that is not in ``labels``.
    """

    def __init__(self, dir, labels, transform=None, split='train', seed=35, train_split=0.7, val_split=0.15):
        self.dir = dir
        self.classes = labels
        self.transform = transform
        label_to_index = {label: index for index, label in enumerate(self.classes)}

        samples = self._list_samples(dir, label_to_index)
        if not samples:
            raise ValueError(f"No .png images found in the class folders of {dir!r}")
        # Mixed int/str rows make numpy store both columns as strings, which
        # is why __getitem__ converts the label back with int().
        self.data = np.array([[index, path] for index, path in samples])

        X_train, X_temp = train_test_split(self.data, train_size=train_split, random_state=seed, stratify=self.data[:, 0])
        # Size the *validation* part of the remainder from val_split. Rounding
        # removes float noise (0.15 / (1 - 0.7) is 0.49999999999999994), which
        # would otherwise shift one sample between validation and test.
        val_fraction = round(val_split / (1 - train_split), 10)
        X_val, X_test = train_test_split(X_temp, train_size=val_fraction, random_state=seed, stratify=X_temp[:, 0])

        if split == 'train':
            self.data = X_train
        elif split == 'val':
            self.data = X_val
        else:
            self.data = X_test

    @staticmethod
    def _list_samples(root, label_to_index, extensions=('.png',)):
        """Return ``(class_index, image_path)`` pairs found below ``root``.

        Directories and files are visited in sorted order so the seeded split
        is reproducible regardless of the order ``os.listdir`` returns.
        Entries of ``root`` that are not directories are ignored, and the file
        extension is matched case-insensitively.
        """
        extensions = tuple(extensions)
        samples = []
        for class_name in sorted(os.listdir(root)):
            class_dir = os.path.join(root, class_name)
            if not os.path.isdir(class_dir):
                continue
            image_names = [name for name in sorted(os.listdir(class_dir))
                           if name.lower().endswith(extensions)]
            if not image_names:
                continue
            if class_name not in label_to_index:
                raise KeyError(f"Directory {class_name!r} contains images but is not listed in the labels")
            class_index = label_to_index[class_name]
            samples.extend((class_index, os.path.join(class_dir, name)) for name in image_names)
        return samples

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        label, img_path = self.data[idx]
        # convert() returns a fully loaded copy, so the file handle can be
        # closed right away instead of lingering until garbage collection.
        with Image.open(img_path) as source:
            img = source.convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img, int(label)
    
# Training-time preprocessing: resize to 64x64 and apply random flips as augmentation.
train_transform = v2.Compose([
            v2.PILToTensor(),
            v2.Resize((64, 64), interpolation=v2.InterpolationMode.NEAREST),
            v2.RandomHorizontalFlip(p=0.5),
            v2.RandomVerticalFlip(p=0.5),
            # Optional augmentations, disabled by default:
            # v2.RandomRotation(45),
            # v2.ColorJitter(brightness=0.2),
            # v2.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
            # v2.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5.)),
            # ConvertImageDtype already rescales uint8 [0, 255] to float32 [0, 1],
            # so no additional division by 255 is applied.
            v2.ConvertImageDtype(torch.float32),
    ])

# Evaluation-time preprocessing: identical to training except for the random
# augmentations, so validation/test inputs have the same value range ([0, 1])
# as the inputs the model is trained on.
test_transform = v2.Compose([
            v2.PILToTensor(),
            v2.Resize((64, 64), interpolation=v2.InterpolationMode.NEAREST),
            v2.ConvertImageDtype(torch.float32),
        ])

Main script for running the training

In [None]:
#Set paths
model_path=os.path.join(os.getcwd(),'models', 'densenet121_PLQ.pth')
label_path=os.path.join(os.getcwd(), 'configs/labels.yaml')
data_path=os.path.join(os.getcwd(), 'data')
log_path=os.path.join(os.getcwd(), 'logs')
new_model_path='new_model.pth'

#Metrics and checkpoints are written to log_path during training, so it must exist
os.makedirs(log_path, exist_ok=True)

#download pretrained densenet121 weights
download_model()

#Set training parameters
num_epochs=100
learning_rate=0.001
batch_size=128
num_workers=1
device=torch.device(f"cuda:{torch.cuda.current_device()}" if torch.cuda.is_available() else "cpu")
print(f'Training is running on {device}')

#Load model and labels
labels_old, labels=load_labels(label_path)
num_new_classes=len(labels)
num_old_classes=len(labels_old)
model=load_model(model_path, device, num_old_classes, num_new_classes)

#Modify model to new input data
mod_model=modify_model(model, num_new_classes).to(device)

#Create the optimizer only after the classifier has been replaced; an optimizer
#built earlier would hold the parameters of the discarded classifier and the
#new output layer would never be updated
optimizer=Adam(mod_model.parameters(), lr=learning_rate)
loss_fn = CrossEntropyLoss().to(device)

#Initialize train and test data loaders
train_dataset = CustomDataset(data_path, split='train', transform=train_transform, labels=labels)
val_dataset = CustomDataset(data_path, split='val', transform=test_transform, labels=labels)
test_dataset = CustomDataset(data_path, split='test', transform=test_transform, labels=labels)
train_data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
#Evaluation does not depend on the sample order, so only the training data is shuffled
val_data_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
test_data_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=num_workers)

#Retrain modified model
finetuned_model,train_losses, val_losses, train_acc, val_acc =train(num_epochs, mod_model, train_data_loader, val_data_loader, optimizer, loss_fn, device, log_path)

#Save finetuned model
torch.save(finetuned_model.state_dict(), new_model_path)
print(f"Model saved: {new_model_path}")

Test the trained model

In [None]:
#Evaluate the fine-tuned model on the held-out test split.
#Relies on label_path, test_data_loader, loss_fn and the loss/accuracy
#histories defined by the training cell.
model_file='new_model.pth'

labels_old, labels=load_labels(label_path)
num_new_classes=len(labels)

device=torch.device(f"cuda:{torch.cuda.current_device()}" if torch.cuda.is_available() else "cpu")

#`weights=None` is the current spelling of the deprecated `pretrained=False`
model=models.densenet121(weights=None)
model.classifier=nn.Linear(model.classifier.in_features, num_new_classes)
#The file was saved from this exact architecture, so load it strictly: any key
#mismatch should raise instead of silently evaluating partly random weights
model.load_state_dict(torch.load(model_file, map_location=torch.device(device)))
model.to(device)

test_precision, test_recall, test_f1_score = test(model, test_data_loader, loss_fn, device)
print(f"Test Precision: {test_precision}")
print(f"Test Recall: {test_recall}")
print(f"Test F1-Score: {test_f1_score}")

plot_losses(train_losses, val_losses)
plot_acc(train_acc, val_acc)