# Damage detection and recognition training pipeline

### Define run parameters
The dataset should follow the VGGFace2/ImageNet-style directory layout (one sub-directory per class). Set `data_dir` and `val_dir` to the training and validation datasets you wish to fine-tune on.

In [42]:
#from facenet_pytorch import InceptionResnetV1, fixed_image_standardization, training
from functions.utils import fixed_image_standardization
from functions import training
import torch
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from torch import optim
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
import numpy as np
import os
# from PIL import Image
# import copy
# import random
from functions import SiameseInceptionResnetV1, siamese_dataset, pass_epoch

In [43]:
# Run configuration.
# Dataset / model locations default to the original author's machine. They can be
# overridden through environment variables so the notebook does not have to be
# edited (and local absolute paths committed) to run elsewhere.
data_dir = os.environ.get(
    "DAMAGE_DATA_DIR",
    "C:/Users/U339700/Documents/Palas/dataset/prueba_siamesas/crop_2class_train",
)
val_dir = os.environ.get(
    "DAMAGE_VAL_DIR",
    "C:/Users/U339700/Documents/Palas/dataset/prueba_siamesas/crop_2class_val",
)
model_dir = os.environ.get("DAMAGE_MODEL_DIR", 'models/')  # checkpoint output directory

batch_size = 32   # samples (pairs) per batch
epochs = 8        # number of training epochs
save_after = 1    # periodic-checkpoint frequency, in epochs
# DataLoader worker processes: 0 (load in the main process) on Windows ('nt'),
# where multiprocessing workers inside notebooks are commonly problematic.
workers = 0 if os.name == 'nt' else 8

### Determine if an nvidia GPU is available

In [44]:
# Use the first CUDA GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
print(f'Running on device: {device}')

Running on device: cpu


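### Define Inception Resnet V1 module
The model is a Siamese wrapper around Inception-ResNet-V1 (`SiameseInceptionResnetV1`). It is instantiated after the data pipeline below; see `help(SiameseInceptionResnetV1)` for more details.

### Define dataset and dataloader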

In [45]:
# Preprocessing / augmentation pipelines.
#
# Every pipeline shares the same head and tail:
#   Resize(256x256) -> [augmentations] -> np.float32 -> ToTensor -> standardise.
# `np.float32` must come before `ToTensor()`: ToTensor rescales uint8 / PIL
# images to [0, 1] but leaves float arrays unscaled. So this order hands raw
# 0-255 float values to `fixed_image_standardization`.
# (Presumably it computes (x - 127.5) / 128 as in facenet-pytorch — confirm
# against functions.utils.)


def _build_transform(*augmentations):
    """Compose Resize -> *augmentations -> float32 -> ToTensor -> standardisation.

    Routing every pipeline through this helper keeps the preprocessing tail
    identical across all variants.
    """
    return transforms.Compose([
        transforms.Resize((256, 256)),
        *augmentations,
        np.float32,
        transforms.ToTensor(),
        fixed_image_standardization,
    ])


def _gaussian_blur():
    """New GaussianBlur with the kernel-size / sigma ranges shared by all blur variants."""
    return transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5))


# Plain preprocessing, no augmentation.
trans = _build_transform()

# Single augmentations; p=1 forces the flip / perspective / equalise to always apply.
data_augm_horz = _build_transform(transforms.RandomHorizontalFlip(p=1))
data_augm_vert = _build_transform(transforms.RandomVerticalFlip(p=1))
data_augm_persp = _build_transform(
    transforms.RandomPerspective(distortion_scale=0.6, p=1.0)
)
# RandomEqualize runs on the PIL image, i.e. before the float32 conversion.
data_augm_eq = _build_transform(transforms.RandomEqualize(p=1))
data_augm_blur = _build_transform(_gaussian_blur())

# Combined augmentations.
data_augm_horz_vert = _build_transform(
    transforms.RandomHorizontalFlip(p=1),
    transforms.RandomVerticalFlip(p=1),
)
data_augm_horz_blur = _build_transform(
    transforms.RandomHorizontalFlip(p=1),
    _gaussian_blur(),
)
data_augm_vert_blur = _build_transform(
    transforms.RandomVerticalFlip(p=1),
    _gaussian_blur(),
)
data_augm_horz_vert_blur = _build_transform(
    transforms.RandomHorizontalFlip(p=1),
    transforms.RandomVerticalFlip(p=1),
    _gaussian_blur(),
)

In [46]:
# Pair datasets for training and validation. Training pairs are shuffled;
# validation pairs keep a fixed order so epoch-to-epoch metrics are comparable.
train_dataset, val_dataset = (
    siamese_dataset(data_dir, shuffle_pairs=True),
    siamese_dataset(val_dir, shuffle_pairs=False),
)

In [47]:
# Batch the pair datasets using the batch size from the config cell, so there is
# a single knob to tune.
# NOTE(review): the config cell's `workers` is not passed as `num_workers`.
# If siamese_dataset is an IterableDataset, each worker would yield the full pair
# stream (duplicated data), so confirm its type before enabling it.
train_loader = DataLoader(train_dataset, batch_size=batch_size)
val_loader = DataLoader(val_dataset, batch_size=batch_size)

In [48]:
# Siamese Inception-ResNet-V1 initialised from VGGFace2-pretrained weights, with
# classify=False (presumably no classification head — see
# help(SiameseInceptionResnetV1)). Then moved to the selected device.
siamese = SiameseInceptionResnetV1(pretrained='vggface2', classify=False)
siamese = siamese.to(device)

### Define optimizer and scheduler

In [51]:
# Adam over all parameters of the Siamese network (backbone included).
optimizer = optim.Adam(params=siamese.parameters(), lr=1e-3)
# Multiply the learning rate by MultiStepLR's default gamma (0.1) at scheduler
# steps 5 and 10.
# NOTE(review): if pass_epoch steps the scheduler once per epoch, milestone 10
# is never reached with epochs = 8 — confirm the intended schedule.
scheduler = MultiStepLR(optimizer, milestones=[5, 10])

### Define loss and evaluation functions

In [52]:
# Binary cross-entropy loss.
# NOTE(review): BCELoss expects probabilities in [0, 1], i.e. model outputs that
# already went through a sigmoid. Confirm SiameseInceptionResnetV1 does this;
# otherwise BCEWithLogitsLoss is the numerically safer choice.
loss_fn = torch.nn.BCELoss()

# Per-batch metrics handed to pass_epoch: throughput ('fps') and accuracy ('acc').
# NOTE(review): if training.accuracy takes an argmax over dim 1 (as
# facenet-pytorch's does), it is not meaningful for a single-output binary
# model — verify against functions.training.
metrics = dict(
    fps=training.BatchTimer(),
    acc=training.accuracy,
)

### Train model

In [61]:
# Training run:
#   1. one validation pass on the freshly loaded model as a baseline;
#   2. `epochs` rounds of train + validate.
# The weights with the lowest validation loss so far go to <model_dir>/best.pth.
# A periodic checkpoint goes to <model_dir>/epoch_<n>.pth every `save_after` epochs.


def _save_checkpoint(path, epoch_number):
    """Save model and optimizer state, tagged with the 1-based `epoch_number`, to `path`."""
    torch.save(
        {
            "epoch": epoch_number,
            "model_state_dict": siamese.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
        },
        path,
    )


# torch.save does not create parent directories; make sure the output dir exists.
os.makedirs(model_dir, exist_ok=True)

writer = SummaryWriter()
# Presumably read by pass_epoch to throttle TensorBoard logging — confirm.
writer.iteration, writer.interval = 0, 10
best_val = float('inf')  # any finite validation loss improves on this

try:
    print('\n\nInitial')
    print('-' * 10)
    siamese.eval()
    pass_epoch(
        siamese, loss_fn, val_loader,
        batch_metrics=metrics, show_running=True, device=device,
        writer=writer
    )

    for epoch in range(epochs):
        print('\nEpoch {}/{}'.format(epoch + 1, epochs))
        print('-' * 10)

        # Train
        siamese.train()
        pass_epoch(
            siamese, loss_fn, train_loader, optimizer, scheduler,
            batch_metrics=metrics, show_running=True, device=device,
            writer=writer
        )

        # Validation
        siamese.eval()
        answ = pass_epoch(
            siamese, loss_fn, val_loader,
            batch_metrics=metrics, show_running=True, device=device,
            writer=writer
        )
        # Assumes pass_epoch returns (mean loss, metrics), as facenet-pytorch's does — confirm.
        loss = answ[0]

        # Keep the weights with the lowest validation loss seen so far.
        if loss < best_val:
            best_val = loss
            print('Saving best weights')
            _save_checkpoint(os.path.join(model_dir, "best.pth"), epoch + 1)

        # Periodic checkpoint every `save_after` epochs, as set in the config cell;
        # a non-positive value disables periodic checkpoints.
        if save_after > 0 and (epoch + 1) % save_after == 0:
            _save_checkpoint(
                os.path.join(model_dir, "epoch_{}.pth".format(epoch + 1)),
                epoch + 1,
            )
finally:
    # Flush and close the TensorBoard writer even if training is interrupted.
    writer.close()



Initial
----------

Epoch 1/8
----------
Saving best weights

Epoch 2/8
----------

Epoch 3/8
----------

Epoch 4/8
----------

Epoch 5/8
----------

Epoch 6/8
----------

Epoch 7/8
----------

Epoch 8/8
----------


In [62]:
# Save the whole model object (architecture + weights) next to the checkpoints.
# With the default model_dir ('models/') this is the same 'models/nuevo.pth' file.
# NOTE(review): saving a full nn.Module pickles it, so loading requires the
# `functions` package (SiameseInceptionResnetV1) to be importable. Prefer saving
# siamese.state_dict() for portability.
os.makedirs(model_dir, exist_ok=True)
torch.save(siamese, os.path.join(model_dir, 'nuevo.pth'))