# Import and drive

Import statements

In [1]:
import zipfile
import os
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import math
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import models
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler

Mount the drive

In [2]:
# mount the drive
from google.colab import drive
drive.mount("/content/drive/")

# enter the desired folder
%cd drive/MyDrive/Colab_environments/WoodClassification/

Mounted at /content/drive/
/content/drive/MyDrive/Colab_environments/WoodClassification


Unzip the training folder

In [3]:
# with zipfile.ZipFile(os.path.join(os.getcwd(), "TRAINING.zip"), 'r') as zip_ref:
#      zip_ref.extractall(os.path.join(os.getcwd(), "TRAINING"))

Set the device

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Utility functions

In [5]:
# Loads a .tiff image using PIL.
# image_path: path to the image to be loaded.
# Returns an image with mode "L".
def load_image(image_path):

    # load the image with PIL
    image = Image.open(image_path)

    # convert it into a numpy array
    image_np = np.array(image)

    # max pixel value
    max_val = np.max(image_np)

    # new desired max value
    new_max_val = 255

    # normalize the pixels and convert their values into unsigned integers
    normalized_image = (image_np / max_val * new_max_val).astype(np.uint8)

    # return the greyscale image represented by the numpy array
    return Image.fromarray(normalized_image, mode="L")

In [6]:
# Function to compute the mean pixel value of images in folder_path.
# folder_path: path to the folder containing images.
# device: device to be used for the computation.
# Returns the mean value of pixels of images in folder_path.
def compute_mean(folder_path, device):

    # initialize the mean and the number of images
    mean = torch.tensor(0.0, device=device)
    n_images = torch.tensor(0, device=device)

    # iterate over each class folder
    for sub_dir_name in os.listdir(folder_path):

        # current sub directory
        curr_sub_dir = os.path.join(folder_path, sub_dir_name)

        # add to samples each tuple for the current class
        for file_name in os.listdir(curr_sub_dir):

            if file_name.endswith(".tiff"):

                # load the current image
                image = load_image(os.path.join(curr_sub_dir, file_name))

                # apply some other requested transformations
                transformations = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
                image = transformations(image).to(device)

                # update the mean and the number of images
                mean += image.mean()
                n_images += 1

    # return the mean over all images in the dataset
    return mean / n_images

In [7]:
# Function to compute the standard deviation of pixels of images in folder_path.
# folder_path: path to the folder containing images.
# device: device to be used for the computation.
# Returns the standard deviation of pixels of images in folder_path.
def compute_std(folder_path, mean, device):

    # initialize the variance and the number of images
    var = torch.tensor(0.0, device=device)
    n_images = torch.tensor(0, device=device)

    # iterate over each class folder
    for sub_dir_name in os.listdir(folder_path):

        # current sub directory
        curr_sub_dir = os.path.join(folder_path, sub_dir_name)

        # add to samples each tuple for the current class
        for file_name in os.listdir(curr_sub_dir):

            if file_name.endswith(".tiff"):

                # load the current image
                image = load_image(os.path.join(curr_sub_dir, file_name))

                # apply some other requested transformations
                transformations = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
                image = transformations(image).to(device)

                # update variance and counter
                var += ((image - mean) ** 2).mean()
                n_images += 1

    # return the std over all images in the dataset
    return torch.sqrt(var / n_images)

# Construct training set and validation set from training images

Class for creating a dataset object to correctly store and process images in the training folder

In [8]:
# number of images per class to be contained in the validation set
VAL_IMAGES = 20

# class to construct a dataset from existing files
class WoodDataset(Dataset):

    # Constructor.
    # root_dir: directory containing the subfolders (one for each class) with training images.
    # transformations: transformations to be applied to images in the dataset.
    # train: boolean variable that must be set to True if the user wants to create the
    #        training set from given images; False to create the validation set.
    def __init__(self, root_dir, transformations=None, train=True):

        # device
        self.device = "cpu"
        if torch.cuda.is_available():
            self.device = "cuda"

        # boolean variable specifying whether the training set or the validation
        # set has to be created
        self.train = train

        # images per class to be placed in the validation set
        self.n_test_per_class = VAL_IMAGES

        # root directory
        self.root_dir = root_dir

        # classes names and indices
        self.classes_labels = self.get_classes_labels()

        # (image_path, class_id) for each image the dataset
        self.items = self.get_samples()

        # transformations to be applied to samples in the dataset
        self.transform = transformations


    # number of items in the training set
    def __len__(self):
        return len(self.items)


    # Retrieves the tuple (image, label) for the image at position index in self.items.
    # index: position of the sample to be retrieved from self.items.
    # Returns (image, label), where image is the tensor representation for the image to retrieve and
    # label is its class index.
    def __getitem__(self, index):

        # extract path and class for the image at position index in the dataset
        path, label_id = self.items[index]

        # load the image using PIL
        image = self.load_image(path)

        # apply other requested transformations, if any
        if self.transform:
            image = self.transform(image).to(self.device)

        return image, label_id


    # creates a dictionary where each class folder name is a key and each value is an integer id
    def get_classes_labels(self):

        # dictionary that will contain folders names and indices
        classes = {}

        # names of the class labels
        classes_names = ["UNK", "s1", "s2", "s3", "s4", "s5", "s6"]

        # extract class names and assign an index to each of them
        for sub_dir in os.scandir(self.root_dir):

            # consider only directories
            if sub_dir.is_dir():

                # extract the name of the directory
                class_name = sub_dir.name

                # assign an id to each folder name, based on classes_names
                for i in range(len(classes_names)):
                    if classes_names[i] in class_name:
                        classes[class_name] = i

        return classes


    # loads the samples contained in the dataset as tuples (path_to_image, label)
    def get_samples(self):

        # list that will contain the tuples
        samples = []

        # iterate over each class folder
        for sub_dir_name in self.classes_labels.keys():

            # samples for the current class
            curr_samples = []

            # id of the current class
            class_id = self.classes_labels[sub_dir_name]

            # path to the directory where images of the current class are stored
            curr_class_dir = os.path.join(self.root_dir, sub_dir_name)

            # add to curr_samples each tuple for the current class
            for file_name in os.listdir(curr_class_dir):
                if file_name.endswith(".tiff"):
                    curr_samples.append((os.path.join(curr_class_dir, file_name), class_id))

            # if we need to construct the training set, add all images but the test ones
            if self.train:
                curr_samples = curr_samples[:-self.n_test_per_class]

            # else, add only test images
            else:
                curr_samples = curr_samples[-self.n_test_per_class:]

            # append to samples elements in curr_samples
            for sample in curr_samples:
                samples.append(sample)

        return samples


    # Loads a .tiff image using PIL.
    # image_path: path to the image to be loaded.
    # Returns the loaded image with mode "L".
    def load_image(self, image_path):

        # load the image with PIL
        image = Image.open(image_path)

        # convert it into a numpy array
        image_np = np.array(image)

        # max pixel value
        max_val = np.max(image_np)

        # new desired max value
        new_max_val = 255

        # normalize the pixels and convert their values into unsigned integers
        normalized_image = (image_np / max_val * new_max_val).astype(np.uint8)

        # return the greyscale image represented by the numpy array
        return Image.fromarray(normalized_image, mode="L")

Define paths and transformations

In [9]:
# path to the training folder
training_set_folder = os.path.join(os.getcwd(), "TRAINING")

# compute mean and std of images in the training folder
mean = compute_mean(os.path.join(os.getcwd(), "TRAINING"), device)
std = compute_std(os.path.join(os.getcwd(), "TRAINING"), mean, device)

# transformations to be applied to each image in the training folder
transformations = transforms.Compose([
                           transforms.Resize((224, 224)),
                           transforms.ToTensor(),
                           transforms.Normalize(mean=[mean], std=[std])
])

Load training set and validation set with corresponding dataloaders

In [10]:
# create a training set with 180 images per class
training_set = WoodDataset(root_dir=training_set_folder, transformations=transformations, train=True)

# create a validation set with the remaining 20 images per class
validation_set = WoodDataset(root_dir=training_set_folder, transformations=transformations, train=False)

# create a dataloader to wrap the training set
batch_size = 64
train_dataloader = DataLoader(training_set, batch_size=batch_size, shuffle=True)

# create a dataloader to wrap the validation set
batch_size = 64
val_dataloader = DataLoader(validation_set, batch_size=batch_size, shuffle=True)

# Training

Load a pre-trained model and modify it according to our needs

In [11]:
# load pre-trained ResNet50
resnet_model = models.resnet50(weights="ResNet50_Weights.DEFAULT")

# print the original architecture of resnet50 model
print(f"\Original model:\n{resnet_model}\n\n")

# modify the input layer to be able to feed the model with greyscale images
resnet_model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)

# modify the output layer to fit with our number of classes
n_classes = 7
resnet_model.fc = nn.Linear(in_features=resnet_model.fc.in_features, out_features=n_classes, bias=True)

# print the modified version of the model
print(f"\nModified model:\n{resnet_model}")

Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 144MB/s]

\Original model:
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1




Function to print statistics on the validation set during training

In [12]:
# Function to print accuracy and average loss on the validation set.
# val_dataloader: dataloader containing the validation set.
# model: model to test.
# loss_function: loss function to be used.
# Prints accuracy and average loss in the validation set.
def val_model(val_dataloader, model, loss_function):

    # number of batches in the dataloader
    num_batches = len(val_dataloader)

    # set the model to evaluation mode
    model.eval()

    # cumulative value of the loss function for the entire validation set
    val_set_loss = 0

    # number of correctly classified samples
    correct = 0

    # ensure that no gradient is computed
    with torch.no_grad():

        # iterate through the batches
        for batch, labels in val_dataloader:

            # move batch and labels to the device
            batch = batch.to(device)
            labels = labels.to(device)

            # make predictions on the current batch
            predicted_labels = model(batch)

            # update the validation set loss by summing the loss for the current batch
            val_set_loss += loss_function(predicted_labels, labels).item()

            # update the number of correctly classified samples by adding those for the current batch
            correct += (predicted_labels.argmax(-1) == labels).to(device).sum().item()

    # compute the mean loss on the validation set by normalizing the sum of losses by the number of batches
    val_set_loss /= num_batches

    # compute accuracy on the validation set by dividing by the number of samples in the dataset
    accuracy = correct / len(val_dataloader.dataset) * 100

    # print test set average loss and accuracy
    print(f"Validation Set | Cross Entropy Loss: {val_set_loss: .3f} | Accuracy: {accuracy: .2f}%")

Function for a training iteration

In [13]:
# Function that trains a model for a single iteration.
# dataloader: dataloader that contains the training set where to train the model.
# model: model to train on dataloader.dataset.
# loss_function: loss function that we want to optimize.
# optimizer: algorithm to be used to train the model.
# lr_sched: learning rate scheduler for updating the learning rate during training.
def train_model(dataloader, model, loss_function, optimizer, lr_sched=None):

    # turn the model to train mode
    model.train()

    # initialize the sum of the losses for all batches
    total_loss = 0

    # initialize the number of correctly predicted labels
    correct_labels = 0

    # extract a batch at the time
    for batch, true_labels in dataloader:

        # move to the device both the current batch and its labels
        batch = batch.to(device)
        true_labels = true_labels.to(device)

        # set to 0 the gradient w.r.t. each parameter of the model
        optimizer.zero_grad()

        # forward pass on the current batch, i.e., output of the model for the current batch
        predicted_labels = model(batch)

        # value of the loss function for the computed predictions
        error = loss_function(predicted_labels, true_labels)

        # compute the gradient of the loss w.r.t. model's parameters
        error.backward()

        # update the model's parameters according to the computed gradient
        optimizer.step()

        # update the learning rate
        if lr_sched:
            lr_sched.step()

        # update the total loss after the computed predictions
        total_loss += error.item()

        # update the number of correctly updated labels
        _, predicted_classes = torch.max(predicted_labels, -1)
        correct_labels += (predicted_classes == true_labels).to(device).sum().item()

    # compute mean loss and accuracy
    mean_loss = total_loss / len(dataloader)
    accuracy = correct_labels / dataloader.dataset.__len__() * 100

    # print some some information, including the current loss value, for the training epoch just performed
    print(f"Training Set   | Cross Entropy Loss: {mean_loss: .3f} | Accuracy: {accuracy: .2f}%")

Function to train and evaluate a model for a given number of training epochs

In [14]:
# Function to train a model for a given number of epochs on the given dataloader.
# train_dataloader: dataloader containing the training set to be used to update the
#                   model's parameters.
# val_dataloader: dataloader containing the validation set to be used to print some
#                 statistics at each iteration on unseen data.
# model: model to be trained.
# loss_function: loss function to be optimized.
# optimizer: algorithm to be used to update model's parameters.
# lr_sched: learning rate scheduler for updating the learning rate during training.
# epochs: number of training iterations.
def train_loop(train_dataloader, val_dataloader, model, loss_function, optimizer, lr_sched=None, epochs=10):

    # train the model for the desired number of epochs
    for i in range(epochs):

        # print current epoch
        print(f"\nIteration {i}")

        # train the model on the training set and print some information
        train_model(train_dataloader, model, loss_function, optimizer, lr_sched)

        # print accuracy and mean loss on the validation set
        val_model(val_dataloader, model, loss_function)

    # print that training is finished
    print(f"\n\nTraining is finished.\n")

Train the model

In [15]:
# hyperparameters
learning_rate = 1e-3
regularization = 2e-05

# move the model to the device
resnet_model.to(device)

# use cross entropy as loss function
loss_function = nn.CrossEntropyLoss()

# INITIALIZE WEIGHTS OF NEW ADDED LAYERS WITH SOME ITERATIONS

# at the beginning, we allow the optimizer to update only parameters of the new layers
parameters_to_update = [
    {"params": resnet_model.conv1.parameters(), "lr": learning_rate},
    {"params": resnet_model.fc.parameters(), "lr": learning_rate}
]

# use Adam as optimizer
optimizer = optim.Adam(parameters_to_update, weight_decay=regularization)

# number of epochs
epochs = 5

# train the model for the desired number of epochs
train_loop(train_dataloader, val_dataloader, resnet_model, loss_function, optimizer, epochs=epochs)

# FINE-TUNE THE MODEL BY UPDATING ALL WEIGHTS

# use Adam as optimizer as before
optimizer = optim.Adam(resnet_model.parameters(), lr=learning_rate, weight_decay=regularization)

# number of epochs
epochs = 30

# train the model for the desired number of epochs
train_loop(train_dataloader, val_dataloader, resnet_model, loss_function, optimizer, epochs=epochs)


Iteration 0
Training Set   | Cross Entropy Loss:  1.925 | Accuracy:  17.78%
Validation Set | Cross Entropy Loss:  1.979 | Accuracy:  20.71%

Iteration 1
Training Set   | Cross Entropy Loss:  1.830 | Accuracy:  31.19%
Validation Set | Cross Entropy Loss:  2.035 | Accuracy:  15.71%

Iteration 2
Training Set   | Cross Entropy Loss:  1.819 | Accuracy:  29.21%
Validation Set | Cross Entropy Loss:  1.967 | Accuracy:  19.29%

Iteration 3
Training Set   | Cross Entropy Loss:  1.707 | Accuracy:  36.27%
Validation Set | Cross Entropy Loss:  1.852 | Accuracy:  35.00%

Iteration 4
Training Set   | Cross Entropy Loss:  1.621 | Accuracy:  42.70%
Validation Set | Cross Entropy Loss:  1.746 | Accuracy:  31.43%


Training is finished.


Iteration 0
Training Set   | Cross Entropy Loss:  1.135 | Accuracy:  56.19%
Validation Set | Cross Entropy Loss:  1.401 | Accuracy:  48.57%

Iteration 1
Training Set   | Cross Entropy Loss:  0.714 | Accuracy:  73.41%
Validation Set | Cross Entropy Loss:  1.195 | Accura

# Save the weights

In [16]:
# save the state dict with learnt weights and not the entire model
torch.save(resnet_model.state_dict(), "weights_30.pth")