## Train model for embedding

Try ViT based, also fine tune the last attention map

### Import libraries

In [1]:
import os
import glob
import json
from os.path import join

import torchvision.datasets as datasets
import torchvision.transforms as transforms
import timm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import random

from ShellRec.data_utils.prepare_photos import get_img_graph, split_graph
from ShellRec.model import TurtleEmbdAttn 
from ShellRec.dataset import TurtlePair 
from ShellRec.train import train_embedding

In [2]:
random.seed(42)

### Preprocessing photo files

In [3]:
all_train = get_img_graph(path = "../dataset/BoxTurtle", drop_p=[0.99,0])
split_graph(all_train, save_path = "../dataset")
_ = get_img_graph(path = "../dataset/BoxTurtle_holdout", 
                  file_to_save = "../dataset/BoxTurtle_holdout.json")

### Prepare torch dataset

In [3]:
## image transformations
transform_train = timm.data.create_transform(384, is_training = True, 
                                   auto_augment = "rand-m9-mstd0.5")


transform_test = transforms.Compose([
        transforms.Resize((384,384)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[.5,.5,.5], 
                             std=[0.5, 0.5, 0.5])
])

train_set = TurtlePair(data_file='../dataset/train.json',transform=transform_train)
val_set = TurtlePair(data_file='../dataset/val.json',transform=transform_test)

# Set up datasets and dataloaders
train_loader = DataLoader( TurtlePair(data_file='../dataset/train.json', 
                                             transform=transform_train), 
                                             batch_size = 64)
val_loader = DataLoader( TurtlePair(data_file='../dataset/val.json', 
                                           transform=transform_test), 
                                           batch_size = 64)

### training

Get the model set up, use a pretrained model as backbone, and add a new head to it.

In [4]:
torch.hub.set_dir('../pretrained/')

In [5]:
#device = "cpu" #
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')   

model = TurtleEmbdAttn() # use a vit backbone
optimizer = optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CosineEmbeddingLoss()
num_epochs = 10

Training loop

In [6]:
loss_list = train_embedding(model, optimizer, criterion, train_loader, val_loader, device, num_epochs,
      save_path = "../pretrained")

Epoch: 1/10


100%|██████████| 19/19 [10:53<00:00, 34.39s/it]


Epoch: 1/10, Training Loss: 0.3626


100%|██████████| 3/3 [01:23<00:00, 27.82s/it]


Epoch: 1/10, Validation loss: 0.85
Epoch: 2/10


100%|██████████| 19/19 [10:49<00:00, 34.20s/it]


Epoch: 2/10, Training Loss: 0.3021


100%|██████████| 3/3 [01:22<00:00, 27.48s/it]


Epoch: 2/10, Validation loss: 0.75
Epoch: 3/10


100%|██████████| 19/19 [10:54<00:00, 34.42s/it]


Epoch: 3/10, Training Loss: 0.3019


100%|██████████| 3/3 [01:23<00:00, 27.73s/it]


Epoch: 3/10, Validation loss: 0.66
Epoch: 4/10


100%|██████████| 19/19 [10:47<00:00, 34.09s/it]


Epoch: 4/10, Training Loss: 0.2100


100%|██████████| 3/3 [01:24<00:00, 28.05s/it]


Epoch: 4/10, Validation loss: 0.63
Epoch: 5/10


100%|██████████| 19/19 [10:56<00:00, 34.54s/it]


Epoch: 5/10, Training Loss: 0.2414


100%|██████████| 3/3 [01:24<00:00, 28.01s/it]


Epoch: 5/10, Validation loss: 0.62
Epoch: 6/10


100%|██████████| 19/19 [10:50<00:00, 34.24s/it]


Epoch: 6/10, Training Loss: 0.2517


100%|██████████| 3/3 [01:22<00:00, 27.63s/it]


Epoch: 6/10, Validation loss: 0.60
Epoch: 7/10


100%|██████████| 19/19 [10:53<00:00, 34.37s/it]


Epoch: 7/10, Training Loss: 0.2190


100%|██████████| 3/3 [01:23<00:00, 27.78s/it]


Epoch: 7/10, Validation loss: 0.55
Epoch: 8/10


100%|██████████| 19/19 [10:46<00:00, 34.05s/it]


Epoch: 8/10, Training Loss: 0.2153


100%|██████████| 3/3 [01:24<00:00, 28.23s/it]


Epoch: 8/10, Validation loss: 0.55
Epoch: 9/10


100%|██████████| 19/19 [10:48<00:00, 34.12s/it]


Epoch: 9/10, Training Loss: 0.1903


100%|██████████| 3/3 [01:24<00:00, 28.25s/it]


Epoch: 9/10, Validation loss: 0.53
Epoch: 10/10


100%|██████████| 19/19 [11:01<00:00, 34.80s/it]


Epoch: 10/10, Training Loss: 0.2061


100%|██████████| 3/3 [01:23<00:00, 27.80s/it]


Epoch: 10/10, Validation loss: 0.53
