In [None]:
import os
import numpy as np
import torch
from torch import nn
from torch.nn import ReLU
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
import torch.multiprocessing
from torchvision import datasets, transforms
from torchvision import models
from torch import optim
from torch.utils.data.dataloader import default_collate

from torch.utils.tensorboard import SummaryWriter
from torchsummary import summary
import torchvision
import matplotlib.pyplot as plt
from colorama import Fore
from IPython.display import Audio, display
from torchsummary import summary
from tqdm.auto import tqdm
from timeit import default_timer as timer
import psutil

# TensorBoard logging: event files for this notebook go under runs/logger_classifier.
writer_path = os.path.join('runs', 'logger_classifier')
# exist_ok=True keeps this cell safe to re-run when the directory is already there.
os.makedirs(name=writer_path, exist_ok=True)
writer = SummaryWriter(log_dir=writer_path)

Definizione delle variabili globali principali.
I worker sono i processi paralleli che caricano i batch del dataset durante le varie epoche.

In [None]:
# --- Data loading ---
NUM_WORKERS = 4  # number of DataLoader worker processes
DATASET_PATH = os.path.join(os.pardir, "spectrograms")  # one sub-folder per class
IMAGE_SIZE = (839, 351)  # spatial size handed to the model summary
CHANNEL_COUNT = 3  # image channels handed to the model summary

# --- Training ---
ACCURACY_THRESHOLD = 85  # validation accuracy (percent) at which training stops early
MAX_EPOCHS = 300  # upper bound on the number of epochs
LEARNING_RATE = 1e-2  # SGD learning rate
GRADIENT_MOMENTUM = 0.9  # SGD momentum

# Select the 'file_system' strategy for sharing tensors between processes.
torch.multiprocessing.set_sharing_strategy('file_system')

Viene caricato il dataset. `datasets.ImageFolder` è una classe di torchvision che carica il dataset assumendo che le immagini siano divise in classi in base al nome della cartella in cui sono contenute; viene anche specificata la trasformazione da applicare ai dati caricati, ovvero la conversione in tensore.
Infine viene fatto lo split tra train e validation.

In [None]:
# Define the data transformation: every loaded image is converted to a tensor
transform = transforms.ToTensor()

# Load the dataset; ImageFolder assigns each image the class of its sub-folder
print(Fore.LIGHTMAGENTA_EX + f"Loading images from dataset at {DATASET_PATH}")
dataset = datasets.ImageFolder(DATASET_PATH, transform=transform)

# train / validation split
# The split is seeded so the same images end up in the validation set on every
# run; without a fixed generator, results are not comparable across kernel restarts.
SPLIT_SEED = 42
val_ratio = 0.2
val_size = int(val_ratio * len(dataset))
train_size = len(dataset) - val_size
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
print(Fore.GREEN + f"{train_size} images for training, {val_size} images for validation")

Divisione in batch di train e validation set

In [None]:
batch_size = 16
val_batch_size = batch_size * 2  # validation uses batches twice as large

# Options shared by both loaders.
# Original note on pin_memory: switch to True if using collate.
shared_loader_options = dict(num_workers=NUM_WORKERS, pin_memory=False)

# Training batches are reshuffled; validation batches keep the dataset order.
train_batches = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, **shared_loader_options)
val_batches = DataLoader(val_dataset, batch_size=val_batch_size, **shared_loader_options)

print(Fore.LIGHTMAGENTA_EX + "Dataset loaded in batches.")
print(Fore.GREEN + f"Batch set to {batch_size} for training")
print(Fore.GREEN + f"Batch set to {val_batch_size} for validation")

Definizione della rete

In [None]:
# CNN classifier: four strided convolutions followed by one linear layer
class neuralNetworkV1(nn.Module):
    """Small convolutional classifier.

    Layout: four 3x3, stride-2 convolutions (3 -> 16 -> 16 -> 10 -> 10 channels),
    each followed by ReLU, with a 2x2 max-pool after the second and after the
    third convolution. The result is flattened and projected by a linear layer
    from 780 features to 50 outputs.

    The linear layer's ``in_features`` (780) is fixed, so the network only
    accepts inputs whose flattened feature count is exactly 780.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(16, 16, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(16, 10, kernel_size=3, stride=2, padding=1)
        self.conv4 = nn.Conv2d(10, 10, kernel_size=3, stride=2, padding=1)
        self.pooling = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()
        self.linear = nn.Linear(in_features=780, out_features=50)

    def forward(self, x: torch.Tensor):
        """Run the network on ``x`` and return the output of the linear layer.

        Raises:
            RuntimeError: if the flattened feature count does not match the
                linear layer's ``in_features``. The message reports the size
                the layer would need for this input.
        """
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.pooling(x)
        x = self.relu(self.conv3(x))
        x = self.pooling(x)
        x = self.relu(self.conv4(x))
        x = self.flatten(x)
        # Fail loudly on a size mismatch instead of returning the un-projected
        # features, which would silently feed wrong-shaped "logits" to the loss.
        expected_features = self.linear.in_features
        if x.shape[-1] != expected_features:
            raise RuntimeError(
                f"Flattened features have shape {tuple(x.shape)} but the linear layer "
                f"expects in_features={expected_features}; "
                f"use in_features={x.shape[-1]} for this input size."
            )
        return self.linear(x)

# Instantiate the network, then pull one batch of training images.
selected_model = neuralNetworkV1()
train_images_sample, _sample_labels = next(iter(train_batches))  # labels are not needed here

In [None]:
print(Fore.LIGHTMAGENTA_EX + "Model summary : " + Fore.GREEN)
# summary() writes the table to stdout itself, so it is not wrapped in print()
# (NOTE(review): with the PyPI torchsummary package the call returns None, which
# print() would show as a stray "None" line - confirm against the installed version).
# The device is passed explicitly so the summary's dummy input lives on the same
# device as the model instead of relying on the library default.
summary(
    selected_model,
    (CHANNEL_COUNT, IMAGE_SIZE[0], IMAGE_SIZE[1]),
    device=next(selected_model.parameters()).device.type,
)
# Record the model graph in TensorBoard, traced with one real batch of images.
writer.add_graph(selected_model, train_images_sample)

Funzioni di utility per l'addestramento della rete

In [None]:
# Report how long training took
def display_training_time(start, end, device):
    """Print the elapsed training time and return it.

    Args:
        start: timer reading taken before training.
        end: timer reading taken after training.
        device: label shown in the message for where training ran.

    Returns:
        ``end - start``; the message formats it as seconds.
    """
    elapsed = end - start
    message = f"Train time on {device}: {elapsed:.3f} seconds"
    print(Fore.LIGHTMAGENTA_EX + message)
    return elapsed

# Percentage of predictions that match the labels
def accuracy_fn(y_true, y_pred):
    """Return the accuracy of ``y_pred`` against ``y_true`` as a percentage.

    Args:
        y_true: tensor of reference labels.
        y_pred: tensor of predicted labels, compared element-wise with ``y_true``.

    Returns:
        Number of equal elements divided by ``len(y_pred)``, times 100.
    """
    matches = torch.eq(y_true, y_pred)
    hit_count = matches.sum().item()
    return (hit_count / len(y_pred)) * 100

# Print a one-line report for an epoch
def display_training_infos(epoch, val_loss, train_loss, accuracy):
    """Print the epoch number, both losses and the accuracy, rounded for display.

    Args:
        epoch: epoch number shown in the message.
        val_loss: validation loss; must expose ``.item()`` (e.g. a scalar tensor).
        train_loss: training loss; must expose ``.item()``.
        accuracy: accuracy value, printed followed by a percent sign.
    """
    shown_val_loss = round(val_loss.item(), 3)
    shown_train_loss = round(train_loss.item(), 3)
    shown_accuracy = round(accuracy, 2)
    print(Fore.GREEN + f"Epoch : {epoch}, Training loss : {shown_train_loss}, Validation loss : {shown_val_loss}, Accuracy : {shown_accuracy} %")

# Check memory usage excess
def check_memory():
    """Warn when RAM or swap usage is high and abort when both are nearly full.

    Reads the current usage percentages from psutil. At 90 % or more of RAM or
    of swap, a warning is printed and announced through the ``say`` command.
    When both RAM and swap are at 95 % or more, a ``MemoryError`` is raised.

    Raises:
        MemoryError: if both memory and swap usage are at least 95 %.
    """
    mem_percent = psutil.virtual_memory().percent
    swap_percent = psutil.swap_memory().percent
    if mem_percent >= 90:
        print(Fore.YELLOW + f"WARNING : Reached {mem_percent} memory usage !")
        os.system('say "Memory usage high"')
    if swap_percent >= 90:
        # Report the swap figure here; the message previously repeated the RAM value.
        print(Fore.YELLOW + f"WARNING : Reached {swap_percent} swap usage !")
        os.system('say "Swap usage high"')
    if mem_percent >= 95 and swap_percent >= 95:
        print(Fore.RED + "ABORTING : Memory and Swap full !")
        os.system('say "Aborting training"')
        raise MemoryError


Training della rete.

In [None]:
# Cross-entropy objective, optimised with SGD plus momentum using the
# hyper-parameters from the configuration cell.
loss_func = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    selected_model.parameters(),
    lr=LEARNING_RATE,
    momentum=GRADIENT_MOMENTUM,
)

def train_neural_net(epochs, model, loss_func, optimizer, train_batches, val_batches):
    """Train ``model`` for up to ``epochs`` epochs, validating after each one.

    Each epoch runs one optimisation pass over ``train_batches`` and one
    evaluation pass over ``val_batches``. Mean training loss, mean validation
    loss and mean validation accuracy (all averaged over batches) are logged to
    the TensorBoard ``writer`` and printed. Training stops early once the
    validation accuracy reaches ``ACCURACY_THRESHOLD``.

    Args:
        epochs: maximum number of epochs to run.
        model: network to train; updated in place.
        loss_func: callable ``loss_func(predictions, labels)`` returning a scalar tensor.
        optimizer: optimizer already bound to ``model``'s parameters.
        train_batches: iterable of ``(images, labels)`` training batches with a length.
        val_batches: iterable of ``(images, labels)`` validation batches with a length.

    Returns:
        Validation accuracy (percent) of the last epoch that ran, or 0 if no
        epoch ran.

    Raises:
        MemoryError: propagated from ``check_memory()``, which is called at the
            start of every epoch.
    """
    final_accuracy = 0
    for epoch in tqdm(range(epochs)):
        # check memory and swap usage
        check_memory()
        # training mode
        model.train()
        with torch.enable_grad():
            train_loss = 0
            for images, labels in train_batches:
                predictions = model(images)
                loss = loss_func(predictions, labels)
                # Accumulate a detached copy: summing `loss` itself would keep every
                # batch's autograd graph alive until the end of the epoch.
                train_loss += loss.detach()
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            train_loss /= len(train_batches)
            writer.add_scalar("training loss", train_loss, epoch)
        # evaluation mode
        val_loss, val_accuracy = 0, 0
        model.eval()
        with torch.inference_mode():
            for images, labels in val_batches:
                predictions = model(images)
                val_loss += loss_func(predictions, labels)
                val_accuracy += accuracy_fn(y_true=labels, y_pred=predictions.argmax(dim=1))
            # Means over batches: a smaller final batch weighs as much as a full one.
            val_loss /= len(val_batches)
            val_accuracy /= len(val_batches)
            writer.add_scalar("validation loss", val_loss, epoch)
            final_accuracy = val_accuracy
        # Epochs are shown 1-based to the user but logged 0-based to TensorBoard.
        display_training_infos(epoch+1, val_loss, train_loss, val_accuracy)
        writer.add_scalar("accuracy", val_accuracy, epoch)
        if val_accuracy >= ACCURACY_THRESHOLD:
            break
    return final_accuracy

print(Fore.LIGHTMAGENTA_EX + "Model ready : ")
print(Fore.GREEN, f"Learning rate set to : {LEARNING_RATE}")
print(Fore.GREEN, f"Momentum set to : {GRADIENT_MOMENTUM}")

print(Fore.LIGHTMAGENTA_EX + "Starting model training...")
train_time_start_on_gpu = timer()
training_complete = False
model_accuracy = train_neural_net(MAX_EPOCHS, selected_model, loss_func, optimizer, train_batches, val_batches)
# Read the clock right after training so the reported time does not include
# the spoken notification below.
train_time_end = timer()
print(Fore.LIGHTCYAN_EX + f"Training complete : {model_accuracy} %")
os.system('say "Training complete"')
training_complete = True
# Report the device the model's parameters actually live on; no earlier cell
# defines a separate device variable to pass here.
display_training_time(start=train_time_start_on_gpu,
                  end=train_time_end,
                  device=next(selected_model.parameters()).device)