In [None]:
!pip3 install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html

In [None]:
!pip install matplotlib scikit-learn pytorch-lightning

In [None]:
!unzip ImageLibrary_6_11_19.zip

In [63]:
import os
from typing import (
    List, Any, Callable, 
    Optional, Tuple, Union, Dict)

import numpy as np
from PIL import Image
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping
from sklearn.metrics import precision_recall_fscore_support
from sklearn.model_selection import train_test_split
import torch
from torch import nn
import torchvision.models as models
from torchvision import transforms

In [64]:
import warnings
warnings.filterwarnings('ignore')

# Initial setup

In [65]:
np.random.seed(42)

data_path = './ImageLibrary_6_11_19'

training_fraction = 0.7
validation_fraction = 0.2
test_fraction = 0.1

# Main classes and helper functions

In [66]:
class ImageFilelistDataset(torch.utils.data.Dataset):
    def __init__(self, 
            image_paths: List[str], 
            labels: List[int],
            transform: Optional[Callable] = None,
            test: bool = False
        ):
        self.img_paths = image_paths
        self.labels = labels
        self.transform = transform or transforms.ToTensor()
        self.test = test

    def _loader(self, path: str) -> Any:
        return Image.open(path).convert('RGB')

    def __getitem__(self, index: int) -> Union[torch.Tensor, Tuple[torch.Tensor, int]]:
        img_path = self.img_paths[index]
        img_rgb = self._loader(img_path)
        img = self.transform(img_rgb)
        img_rgb.close()

        if self.test:
            return img

        target = self.labels[index]
        return img, target

    def __len__(self) -> int:
        return len(self.img_paths)

In [41]:
class ImageDataModule(pl.LightningDataModule):
    def __init__(
            self, 
            data_path: str, 
            batch_size: int = 32, 
            img_transforms: Optional[Callable] = None
        ):
        super().__init__()

        self.data_path = data_path
        self.batch_size = batch_size
        self.transforms = img_transforms
    
    def _load_paths(self) -> Tuple[List[str], List[int]]:
        """
        Finds paths of images of each class 
        and randomly selects the same number of images of each class.
        Thus, the number of samples of each class is the same.
        Returns a list with image paths and a list of class labels.
        """
        class_dirs = os.listdir(self.data_path)

        min_class_imgs_count = min([
            len(os.listdir(os.path.join(data_path, class_dir))) 
            for class_dir in class_dirs
        ])

        data_paths = []
        classes = []

        for i, class_dir in enumerate(class_dirs):
            class_files = os.listdir(os.path.join(data_path, class_dir))
            class_files_random = np.random.choice(
                class_files, 
                size=min_class_imgs_count, 
                replace=False
            )
            data_paths.extend([
                os.path.join(data_path, class_dir, class_file)
                for class_file in class_files_random
            ])
            classes.extend([i]*min_class_imgs_count) 
        return data_paths, classes
    
    def _data_split(self, data_paths: List[str], classes: List[int]):
        self.X_train, X_rem, self.y_train, y_rem = train_test_split(
            data_paths, classes, 
            test_size=validation_fraction + test_fraction, 
            stratify=classes
        )
        self.X_val, self.X_test, self.y_val, self.y_test = train_test_split(
            X_rem, y_rem, 
            test_size=test_fraction / (validation_fraction+test_fraction), 
            stratify=y_rem
        )

    def setup(self, stage: Optional[str] = None):
        data_paths, classes = self._load_paths()
        self._data_split(data_paths, classes)

    def train_dataloader(self):
        data = ImageFilelistDataset(
            self.X_train, self.y_train, transform=self.transforms['train'])
        return torch.utils.data.DataLoader(
            dataset=data,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=8
        )

    def val_dataloader(self):
        data = ImageFilelistDataset(
            self.X_val, self.y_val, transform=self.transforms['default'])
        return torch.utils.data.DataLoader(
            dataset=data,
            batch_size=self.batch_size,
            num_workers=8
        )

    def test_dataloader(self):
        data = ImageFilelistDataset(
            self.X_test, self.y_test, transform=self.transforms['default'])
        return torch.utils.data.DataLoader(
            dataset=data,
            batch_size=self.batch_size,
            num_workers=8
        )

    def predict_dataloader(self):
        data = ImageFilelistDataset(
            self.X_test, self.y_test, transform=self.transforms['default'], test=True)
        return torch.utils.data.DataLoader(
            dataset=data,
            num_workers=8
        )

In [93]:
class ResNet50(pl.LightningModule):
    def __init__(self, num_target_classes: int):
        super().__init__()

        self.model = self._build_model(num_target_classes)
        self.num_target_classes = num_target_classes

        self.loss = nn.CrossEntropyLoss()

    def _build_model(self, num_classes: int) -> models.resnet.ResNet:
        model = models.resnet50(pretrained=True)
        model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
        return model
    
    def _log_metric(self, metric: str, value: float):
        self.log(metric, value, prog_bar=True, on_epoch=True, on_step=False)

    def forward(self, x):
        return self.model(x)
    
    def training_step(self, batch, batch_idx: int): 
        loss, acc = self._step(batch)
        self._log_metric("train_loss", loss)
        self._log_metric("train_acc", acc)      
        return loss

    def validation_step(self, batch, batch_idx: int): 
        loss, acc = self._step(batch)
        self._log_metric("val_loss", loss)
        self._log_metric("val_acc", acc)    

    def test_step(self, batch, batch_idx: int): 
        loss, acc = self._step(batch)
        self._log_metric("test_loss", loss)
        self._log_metric("test_acc", acc)               
    
    def _step(self, batch):
        x, y = batch
        y_pred = self.forward(x)
        acc = self.acc(y_pred, y)
        loss = self.loss(y_pred, y)
        return loss, acc

    def acc(self, y_pred, y_target):
        return (y_target == torch.argmax(y_pred, 1)).type(torch.FloatTensor).mean()

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-4)

In [94]:
def get_binary_metrics(
        y_true: List[int], 
        y_pred: List[torch.Tensor], 
        target_class: int = 3) -> Dict[str, float]:
    """
    Converts the results to a binary form according to `target_class`. 
    Calculates the precision, recall and F1 score.
    """
    y_true = np.array(y_true)
    y_true[y_true != target_class] = 0
    y_true[y_true == target_class] = 1

    y_pred = [v[0] for v in y_pred]
    y_pred = torch.argmax(torch.stack(y_pred), dim=1)
    y_pred[y_pred != target_class] = 0
    y_pred[y_pred == target_class] = 1

    precision, recall, f1, _ = precision_recall_fscore_support(
        y_true, y_pred.cpu(), average='binary')
    return {
        "precision": precision,
        "recall": recall,
        "f1_score": f1
    }

# Training

In [95]:
img_transforms = {
    'train': transforms.Compose([
        transforms.ToTensor(),
        transforms.RandomAffine(
            degrees=(-180, 180), 
            translate=(0.228, 0.228) # Translate in pixels is 175*0.228 = 39.9 px
        ),
        transforms.RandomVerticalFlip(),
        transforms.RandomHorizontalFlip(),
        transforms.Resize((224, 224))
    ]),
    'default': transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((224, 224))
    ])
}

In [96]:
data_module = ImageDataModule(data_path, img_transforms=img_transforms)
data_module.setup()
num_target_classes = len(os.listdir(data_path))
target_class = os.listdir(data_path).index('trophallaxis')

In [97]:
model = ResNet50(num_target_classes)
trainer = pl.Trainer(
    gpus=1, max_epochs=20, 
    callbacks=[EarlyStopping('val_loss', patience=5)]
)
trainer.fit(model, data_module)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type             | Params
-------------------------------------------
0 | model | ResNet           | 23.5 M
1 | loss  | CrossEntropyLoss | 0     
-------------------------------------------
23.5 M    Trainable params
0         Non-trainable params
23.5 M    Total params
94.065    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: -1it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

# Result metrics

## Train loss and accuracy

In [102]:
trainer.validate(model, data_module.train_dataloader())

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Validating: 0it [00:00, ?it/s]

--------------------------------------------------------------------------------
DATALOADER:0 VALIDATE RESULTS
{'val_acc': 0.9464057683944702, 'val_loss': 0.1531306803226471}
--------------------------------------------------------------------------------


[{'val_loss': 0.1531306803226471, 'val_acc': 0.9464057683944702}]

## Validation loss and accuracy

In [101]:
trainer.validate(model, data_module.val_dataloader())

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Validating: 0it [00:00, ?it/s]

--------------------------------------------------------------------------------
DATALOADER:0 VALIDATE RESULTS
{'val_acc': 0.913690447807312, 'val_loss': 0.22848735749721527}
--------------------------------------------------------------------------------


[{'val_loss': 0.22848735749721527, 'val_acc': 0.913690447807312}]

## Test loss and accuracy

In [103]:
trainer.validate(model, data_module.test_dataloader())

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Validating: 0it [00:00, ?it/s]

--------------------------------------------------------------------------------
DATALOADER:0 VALIDATE RESULTS
{'val_acc': 0.9317507147789001, 'val_loss': 0.23040859401226044}
--------------------------------------------------------------------------------


[{'val_loss': 0.23040859401226044, 'val_acc': 0.9317507147789001}]

## Test binary classification metrics

In [105]:
y_pred = trainer.predict(model, data_module.predict_dataloader())
get_binary_metrics(data_module.y_test, y_pred)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 74it [00:00, ?it/s]

{'precision': 1.0, 'recall': 0.9285714285714286, 'f1_score': 0.962962962962963}

# Save models

In [106]:
torch.save(model.model, 'resnet50.pt')

In [107]:
trainer.save_checkpoint('resnet50.ckpt')

# Check models loading

In [108]:
loaded_model = torch.load('resnet50.pt')
loaded_model.eval()

new_model = ResNet50(num_target_classes)
new_model.model = loaded_model

y_pred = trainer.predict(new_model, data_module.predict_dataloader())
get_binary_metrics(data_module.y_test, y_pred)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 74it [00:00, ?it/s]

{'precision': 1.0, 'recall': 0.9285714285714286, 'f1_score': 0.962962962962963}

In [109]:
loaded_model = ResNet50.load_from_checkpoint('resnet50.ckpt', num_target_classes=num_target_classes)
y_pred = trainer.predict(loaded_model, data_module.predict_dataloader())
get_binary_metrics(data_module.y_test, y_pred)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 74it [00:00, ?it/s]

{'precision': 1.0, 'recall': 0.9285714285714286, 'f1_score': 0.962962962962963}

# Binary metrics for the whole dataset

In [110]:
full_data = data_module.X_train + data_module.X_val + data_module.X_test
full_target = data_module.y_train + data_module.y_val + data_module.y_test

In [111]:
data = ImageFilelistDataset(full_data, full_target, transform=img_transforms['default'], test=True)
full_dataloader = torch.utils.data.DataLoader(
    dataset=data,
    num_workers=8
)

In [112]:
y_pred = trainer.predict(model, full_dataloader)
get_binary_metrics(full_target, y_pred)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 74it [00:00, ?it/s]

{'precision': 0.9987029831387808,
 'recall': 0.9166666666666666,
 'f1_score': 0.9559279950341403}