## Train model identify if two images are the same turtle

### Import libraries

In [1]:
import os
import glob
import json
from os.path import join

import torchvision.datasets as datasets
import torchvision.transforms as transforms
import timm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import random

from ShellRec.data_utils.prepare_photos import get_img_graph, split_graph
from ShellRec.model import TurtleDiffPool 
from ShellRec.dataset import TurtlePair 
from ShellRec.train import train

In [2]:
random.seed(42)

### Preprocessing photo files

In [3]:
all_train = get_img_graph(path = "../dataset/BoxTurtle", drop_p=[0.99,0])
split_graph(all_train, save_path = "../dataset")
_ = get_img_graph(path = "../dataset/BoxTurtle_holdout", 
                  file_to_save = "../dataset/BoxTurtle_holdout.json")

### Prepare torch dataset

In [3]:
## image transformations
transform_train = timm.data.create_transform(384, is_training = True, 
                                   auto_augment = "rand-m9-mstd0.5")


transform_test = transforms.Compose([
        transforms.Resize((384,384)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[.5,.5,.5], 
                             std=[0.5, 0.5, 0.5])
])

train_set = TurtlePair(data_file='../dataset/train.json',transform=transform_train)
val_set = TurtlePair(data_file='../dataset/val.json',transform=transform_test)

# Set up datasets and dataloaders
train_loader = DataLoader( TurtlePair(data_file='../dataset/train.json', 
                                             transform=transform_train), 
                                             batch_size = 64)
val_loader = DataLoader( TurtlePair(data_file='../dataset/val.json', 
                                           transform=transform_test), 
                                           batch_size = 64)

### training

Get the model set up, use a pretrained model as backbone, and add a new head to it.

In [4]:
torch.hub.set_dir('../pretrained/')

In [5]:
#device = "cpu" #
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')   

model = TurtleDiffPool('vit_base_patch16_384') # use a vit backbone
#model = TurtleDiff('resnet50') # use a resnet backbone
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
num_epochs = 20

Training loop

In [6]:
loss_list = train(model, optimizer, criterion, train_loader, val_loader, device, num_epochs,
      save_path = "../pretrained")

Epoch: 1/20


100%|██████████| 20/20 [10:42<00:00, 32.12s/it]


Epoch: 1/20, Training Loss: 0.7462


100%|██████████| 3/3 [01:20<00:00, 27.00s/it]


Epoch: 1/20, Validation Accuracy: 54.61%
Epoch: 2/20


100%|██████████| 20/20 [10:37<00:00, 31.86s/it]


Epoch: 2/20, Training Loss: 0.6410


100%|██████████| 3/3 [01:24<00:00, 28.12s/it]


Epoch: 2/20, Validation Accuracy: 78.29%
Epoch: 3/20


100%|██████████| 20/20 [10:40<00:00, 32.04s/it]


Epoch: 3/20, Training Loss: 0.3997


100%|██████████| 3/3 [01:21<00:00, 27.19s/it]


Epoch: 3/20, Validation Accuracy: 80.92%
Epoch: 4/20


100%|██████████| 20/20 [10:41<00:00, 32.07s/it]


Epoch: 4/20, Training Loss: 0.4753


100%|██████████| 3/3 [01:23<00:00, 27.92s/it]


Epoch: 4/20, Validation Accuracy: 84.21%
Epoch: 5/20


100%|██████████| 20/20 [10:40<00:00, 32.04s/it]


Epoch: 5/20, Training Loss: 0.2396


100%|██████████| 3/3 [01:21<00:00, 27.29s/it]


Epoch: 5/20, Validation Accuracy: 82.24%
Epoch: 6/20


100%|██████████| 20/20 [10:43<00:00, 32.17s/it]


Epoch: 6/20, Training Loss: 0.4424


100%|██████████| 3/3 [01:26<00:00, 28.89s/it]


Epoch: 6/20, Validation Accuracy: 82.24%
Epoch: 7/20


100%|██████████| 20/20 [10:37<00:00, 31.89s/it]


Epoch: 7/20, Training Loss: 0.1193


100%|██████████| 3/3 [01:24<00:00, 28.22s/it]


Epoch: 7/20, Validation Accuracy: 86.18%
Epoch: 8/20


100%|██████████| 20/20 [10:42<00:00, 32.15s/it]


Epoch: 8/20, Training Loss: 0.1061


100%|██████████| 3/3 [01:22<00:00, 27.33s/it]


Epoch: 8/20, Validation Accuracy: 87.50%
Epoch: 9/20


100%|██████████| 20/20 [10:46<00:00, 32.30s/it]


Epoch: 9/20, Training Loss: 0.0234


100%|██████████| 3/3 [01:22<00:00, 27.35s/it]


Epoch: 9/20, Validation Accuracy: 86.84%
Epoch: 10/20


100%|██████████| 20/20 [10:41<00:00, 32.06s/it]


Epoch: 10/20, Training Loss: 0.0565


100%|██████████| 3/3 [01:24<00:00, 28.08s/it]


Epoch: 10/20, Validation Accuracy: 85.53%
Epoch: 11/20


100%|██████████| 20/20 [09:54<00:00, 29.71s/it]


Epoch: 11/20, Training Loss: 0.0201


100%|██████████| 3/3 [01:11<00:00, 23.81s/it]


Epoch: 11/20, Validation Accuracy: 89.47%
Epoch: 12/20


100%|██████████| 20/20 [08:39<00:00, 25.96s/it]


Epoch: 12/20, Training Loss: 0.0436


100%|██████████| 3/3 [01:07<00:00, 22.62s/it]


Epoch: 12/20, Validation Accuracy: 90.13%
Epoch: 13/20


100%|██████████| 20/20 [08:57<00:00, 26.86s/it]


Epoch: 13/20, Training Loss: 0.0174


100%|██████████| 3/3 [01:16<00:00, 25.66s/it]


Epoch: 13/20, Validation Accuracy: 90.13%
Epoch: 14/20


100%|██████████| 20/20 [08:51<00:00, 26.56s/it]


Epoch: 14/20, Training Loss: 0.0090


100%|██████████| 3/3 [01:08<00:00, 22.87s/it]


Epoch: 14/20, Validation Accuracy: 89.47%
Epoch: 15/20


100%|██████████| 20/20 [08:42<00:00, 26.12s/it]


Epoch: 15/20, Training Loss: 0.3118


100%|██████████| 3/3 [01:07<00:00, 22.66s/it]


Epoch: 15/20, Validation Accuracy: 92.11%
Epoch: 16/20


100%|██████████| 20/20 [08:47<00:00, 26.39s/it]


Epoch: 16/20, Training Loss: 0.1265


100%|██████████| 3/3 [01:09<00:00, 23.09s/it]


Epoch: 16/20, Validation Accuracy: 92.11%
Epoch: 17/20


100%|██████████| 20/20 [08:47<00:00, 26.36s/it]


Epoch: 17/20, Training Loss: 0.0369


100%|██████████| 3/3 [01:09<00:00, 23.31s/it]


Epoch: 17/20, Validation Accuracy: 91.45%
Epoch: 18/20


100%|██████████| 20/20 [11:46<00:00, 35.34s/it]


Epoch: 18/20, Training Loss: 0.2280


100%|██████████| 3/3 [01:29<00:00, 29.96s/it]


Epoch: 18/20, Validation Accuracy: 92.11%
Epoch: 19/20


100%|██████████| 20/20 [11:42<00:00, 35.13s/it]


Epoch: 19/20, Training Loss: 0.0662


100%|██████████| 3/3 [01:32<00:00, 30.83s/it]


Epoch: 19/20, Validation Accuracy: 90.13%
Epoch: 20/20


100%|██████████| 20/20 [11:39<00:00, 34.98s/it]


Epoch: 20/20, Training Loss: 0.0331


100%|██████████| 3/3 [01:31<00:00, 30.42s/it]

Epoch: 20/20, Validation Accuracy: 92.11%



