In [1]:
from PIL import Image
import numpy as np
import os

import torch
from torch.utils.data import Dataset
import torch.optim as optim
import torch.nn as nn
from torchvision import transforms

from sklearn.metrics import classification_report, fbeta_score

from tqdm import tqdm
import torch.nn.functional as F

In [2]:
# Training hyperparameters shared by the cells below:
# BATCH_SIZE is the mini-batch size handed to every DataLoader,
# LEARNING_RATE is the step size handed to the optimizer.
BATCH_SIZE, LEARNING_RATE = 16, 5e-2

In [3]:
# Choose the compute backend in order of preference: CUDA, then the
# Apple-silicon GPU (MPS), and finally the CPU as the fallback.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"  # Use M1 Mac GPU
else:
    device = "cpu"
print(device)

mps


In [5]:
class ExtractorClassifierDataset(Dataset):
    """Map-style dataset over ``(image, label)`` pairs.

    Each item is returned as ``(image_tensor, one_hot_label)``: the image is
    converted with ``transforms.ToTensor``, resized to ``resize_shape`` and
    divided by ``pixel_divisor``; the label becomes a float one-hot vector of
    length ``num_classes``.

    Args:
        data: Indexable collection of ``(image, label)`` pairs, where
            ``label`` is an integer class index.
        resize_shape: Target ``(height, width)`` passed to
            ``transforms.Resize``.
        num_classes: Length of the one-hot label vector (default 2, the value
            that used to be hard-coded).
        pixel_divisor: Extra divisor applied after ``ToTensor`` and the
            resize. NOTE(review): ``ToTensor`` already rescales 8-bit PIL
            images to [0, 1], so the default of 255 leaves pixel values in
            roughly [0, 1/255]. The default is kept because checkpoints
            produced by this notebook were presumably trained on that range;
            pass ``pixel_divisor=1`` for the conventional [0, 1] range when
            training a new model.
    """

    def __init__(self, data, resize_shape=(256, 256), num_classes=2, pixel_divisor=255):
        self.data = data
        self.num_classes = num_classes
        self.pixel_divisor = pixel_divisor
        self.to_tensor = transforms.ToTensor()
        self.resize_image = transforms.Resize(resize_shape, antialias=True)

    def __len__(self):
        """Return the number of ``(image, label)`` pairs."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image_tensor, one_hot_label)`` for sample ``idx``.

        Raises:
            IndexError: If the stored label is outside
                ``[0, num_classes)``. Negative labels are rejected as well;
                they would otherwise silently wrap around to the last class.
        """
        image, label = self.data[idx]

        image_tensor = self.resize_image(self.to_tensor(image))
        image_tensor = image_tensor / self.pixel_divisor

        if not 0 <= label < self.num_classes:
            raise IndexError(
                f"label {label} is out of range for {self.num_classes} classes"
            )
        label_tensor = torch.zeros(self.num_classes)
        label_tensor[label] = 1
        return image_tensor, label_tensor

data_dir = "../../../datasets/extractor_classifier"

# Sub-folder name -> integer class index used for the one-hot labels.
# NOTE(review): "not_relevant" is class 1, which is the class sklearn's binary
# F-beta treats as positive by default — confirm that is intended.
class_labels = {"relevant": 0, "not_relevant": 1}


def load_split(split_name):
    """Load every image of one split into memory as ``(PIL image, label)`` pairs.

    Reads ``{data_dir}/{split_name}/{class_name}`` for each class in
    ``class_labels``. Hidden entries (``.ipynb_checkpoints``, ``.DS_Store``,
    ...) and sub-directories are skipped, and files are visited in sorted
    order so the resulting list is deterministic.

    Args:
        split_name: One of the split folders, e.g. ``"train"``.

    Returns:
        List of ``(image, label)`` tuples; all samples of the first class in
        ``class_labels`` come before those of the second.
    """
    samples = []
    for class_name, label in class_labels.items():
        class_dir = os.path.join(data_dir, split_name, class_name)
        for file_name in sorted(os.listdir(class_dir)):
            file_path = os.path.join(class_dir, file_name)
            if file_name.startswith(".") or not os.path.isfile(file_path):
                continue
            # Image.open is lazy and keeps the file handle open until the
            # pixels are read. Copying forces the decode, and leaving the
            # `with` block closes the file, so large splits do not exhaust
            # the process's open-file limit.
            with Image.open(file_path) as image_file:
                samples.append((image_file.copy(), label))
    return samples


train_list = load_split("train")
validation_list = load_split("validation")
test_list = load_split("test")

# Define the train/validation/test datasets
train_data = ExtractorClassifierDataset(train_list)
validation_data = ExtractorClassifierDataset(validation_list)
test_data = ExtractorClassifierDataset(test_list)

In [6]:
# Define the dataloaders for each dataset.
# Only the training loader is shuffled: shuffling gives no benefit during
# evaluation, and a fixed order keeps validation/test batches reproducible
# from run to run.
train_loader = torch.utils.data.DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
validation_loader = torch.utils.data.DataLoader(validation_data, batch_size=BATCH_SIZE, shuffle=False)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)

In [7]:
from torchvision.models import resnet50

class Resnet50Model(nn.Module):
    """ResNet-50 backbone whose final fully connected layer is replaced.

    Args:
        pretrained: When true, start from torchvision's ``IMAGENET1K_V1``
            weights (the set the legacy ``pretrained=True`` flag selected);
            otherwise the backbone is randomly initialised.
        num_classes: Number of output logits of the replacement head
            (default 2, the value that used to be hard-coded).
    """

    def __init__(self, pretrained=False, num_classes=2):
        super().__init__()
        # `pretrained=` is deprecated in torchvision; `weights=` is its
        # replacement and accepts the weight-set name as a string.
        weights = "IMAGENET1K_V1" if pretrained else None
        self.resnet_model = resnet50(weights=weights)
        # Swap the 1000-way ImageNet head for one sized to this task.
        num_features = self.resnet_model.fc.in_features
        self.resnet_model.fc = nn.Linear(num_features, num_classes)

    def forward(self, x):
        """Return the raw class logits for an image batch ``x``."""
        return self.resnet_model(x)

In [10]:
def train(model, model_path, train_loader, validation_loader, test_loader, criterion, optimizer, epochs):
    """Train ``model``, checkpoint the best epoch, then evaluate it on the test set.

    After every epoch the model is evaluated with ``validate_model``; whenever
    the validation F-beta (beta=0.5) improves, the state dict is written to
    ``model_path``. Once all epochs are done the best checkpoint is reloaded
    and evaluated on ``test_loader``.

    Note that validation/test losses come from ``validate_model``, which uses
    the notebook-level ``criterion`` rather than the argument passed here.

    Args:
        model: Network to train; it is moved to the notebook-level ``device``.
        model_path: File the best state dict is saved to and reloaded from.
        train_loader: Loader yielding ``(images, one_hot_labels)`` batches.
        validation_loader: Loader used for per-epoch model selection.
        test_loader: Loader used for the final evaluation.
        criterion: Loss applied to ``(logits, one_hot_labels)`` while training.
        optimizer: Optimizer already bound to ``model``'s parameters.
        epochs: Number of passes over ``train_loader``.

    Returns:
        Tuple ``(model, train_loss_values, train_f_beta_values,
        validation_loss_values, validation_f_beta_values)``; each list holds
        one entry per epoch and ``model`` carries the best checkpoint's weights.
    """
    # Move the model to the device
    model.to(device)

    # torch.save does not create missing parent directories.
    checkpoint_dir = os.path.dirname(model_path)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    # Start below every attainable score so the first epoch always writes a
    # checkpoint. With a 0.0 start, a run whose validation F-beta stayed at 0
    # never saved anything and the final torch.load failed.
    best_valid_f_beta = float("-inf")
    train_loss_values = []
    validation_loss_values = []
    train_f_beta_values = []
    validation_f_beta_values = []

    for epoch in range(epochs):
        # Per-epoch accumulators. The label lists used to live outside this
        # loop, so each epoch's train F-beta also included the predictions
        # of all previous epochs.
        train_loss = 0.0
        train_true_labels = []
        train_predicted_labels = []

        # validate_model() leaves the model in eval mode, so switch back.
        model.train()
        for images, labels in tqdm(train_loader):
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Weight the batch-mean loss by the batch size so the epoch
            # average stays exact when the last batch is smaller.
            train_loss += loss.item() * images.size(0)

            # argmax over the logits equals argmax over their softmax, and
            # argmax over a one-hot target recovers the class index.
            train_predicted_labels.extend(outputs.argmax(dim=1).tolist())
            train_true_labels.extend(labels.argmax(dim=1).tolist())

        train_loss = train_loss / len(train_loader.dataset)
        train_loss_values.append(train_loss)

        train_f_beta = fbeta_score(train_true_labels, train_predicted_labels, average="binary", beta=0.5)
        train_f_beta_values.append(train_f_beta)

        print(f"Epoch {epoch+1}/{epochs} - "
              f"Train Loss: {train_loss:.4f} - Train Fbeta: {train_f_beta:.4f}")

        # Evaluate the model on the validation set
        validation_loss, validation_f_beta = validate_model(model, validation_loader)
        validation_loss_values.append(validation_loss)
        validation_f_beta_values.append(validation_f_beta)

        # Keep the checkpoint with the best validation F-beta seen so far
        if validation_f_beta > best_valid_f_beta:
            best_valid_f_beta = validation_f_beta
            print(f"Updated best model: epoch {epoch+1}")
            torch.save(model.state_dict(), model_path)

    # Load the best model state and evaluate on the test set
    model.load_state_dict(torch.load(model_path, map_location=torch.device(device)))
    print("Test model:")
    test_loss, test_f_beta = validate_model(model, test_loader)
    print(f"Test Loss: {test_loss:.4f} - Test Fbeta: {test_f_beta:.4f}")

    return model, train_loss_values, train_f_beta_values, validation_loss_values, validation_f_beta_values

def validate_model(model, data_loader, loss_fn=None):
    """Evaluate ``model`` on ``data_loader`` and print a summary.

    Prints the average loss, the binary F-beta score (beta=0.5) and sklearn's
    classification report. The F-beta is computed for sklearn's default
    positive class, i.e. label 1.

    Side effects: the model is moved to the notebook-level ``device`` and is
    left in eval mode.

    Args:
        model: Network producing one logit per class.
        data_loader: Loader yielding ``(images, one_hot_labels)`` batches.
        loss_fn: Loss applied to ``(logits, one_hot_labels)``. When omitted,
            the notebook-level ``criterion`` is used, which keeps existing
            two-argument calls working unchanged.

    Returns:
        Tuple ``(average_loss, f_beta_score)``.
    """
    if loss_fn is None:
        # Backward-compatible fallback to the notebook-level loss.
        loss_fn = criterion

    model.to(device)
    model.eval()

    valid_loss = 0.0
    valid_true_labels = []
    valid_predicted_labels = []

    with torch.no_grad():
        for images, labels in data_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)

            # Weight the batch-mean loss by the batch size so the dataset
            # average stays exact when the last batch is smaller.
            valid_loss += loss_fn(outputs, labels).item() * images.size(0)

            # argmax over the logits equals argmax over their softmax, and
            # argmax over a one-hot target recovers the class index.
            valid_predicted_labels.extend(outputs.argmax(dim=1).tolist())
            valid_true_labels.extend(labels.argmax(dim=1).tolist())

    f_beta_score = fbeta_score(valid_true_labels, valid_predicted_labels, average="binary", beta=0.5)
    valid_loss = valid_loss / len(data_loader.dataset)

    print(f"Valid Loss: {valid_loss:.4f} - Valid Fbeta: {f_beta_score:.4f}")
    print(classification_report(valid_true_labels, valid_predicted_labels))
    return valid_loss, f_beta_score

In [14]:
# Number of training epochs, used when no checkpoint exists yet.
EPOCHS = 50

resnet_model = Resnet50Model()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(resnet_model.parameters(), lr=LEARNING_RATE)
model_path = "model_results/resnet50_model_50.pt"

# Neither torch.save nor np.save creates missing directories, so make sure
# the output folder exists before anything is written.
os.makedirs(os.path.dirname(model_path), exist_ok=True)

if os.path.isfile(model_path):
    # A checkpoint is already on disk: skip training and only report metrics.
    print("Evaluate model")
    resnet_model.load_state_dict(torch.load(model_path, map_location=torch.device(device)))
    print("Validation data:")
    validate_model(resnet_model, validation_loader)
    print("Test data:")
    validate_model(resnet_model, test_loader)
else:
    print("Train model")
    model, train_loss_values, train_f_beta_values, validation_loss_values, validation_f_beta_values = train(
        resnet_model, model_path, train_loader, validation_loader, test_loader, criterion, optimizer, EPOCHS
    )
    # Persist the per-epoch curves next to the checkpoint (np.save appends ".npy").
    for curve_name, curve_values in (
        ("train_loss", train_loss_values),
        ("train_f_beta", train_f_beta_values),
        ("valid_loss", validation_loss_values),
        ("valid_f_beta", validation_f_beta_values),
    ):
        np.save(f"model_results/resnet50_model_50_{curve_name}", curve_values)



Evaluate model
Validation data:
Valid Loss: 0.3440 - Valid Fbeta: 0.8021
              precision    recall  f1-score   support

           0       0.92      0.98      0.95       208
           1       0.86      0.64      0.73        47

    accuracy                           0.91       255
   macro avg       0.89      0.81      0.84       255
weighted avg       0.91      0.91      0.91       255

Test data:
Valid Loss: 0.2900 - Valid Fbeta: 0.8355
              precision    recall  f1-score   support

           0       0.92      0.98      0.95       418
           1       0.90      0.65      0.75        97

    accuracy                           0.92       515
   macro avg       0.91      0.82      0.85       515
weighted avg       0.92      0.92      0.92       515

