In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Load the autoreload extension; mode 2 re-imports all modules before each
# cell execution, so edits to the local `src` package are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from tensorboardX import SummaryWriter
import torch

from src import data_prep, dataset, train, visualize
from src.loss import ContrastiveLoss
from src.models import segnet

### Read and process the data

In [None]:
# Machine-specific absolute path of the training image folder.
train_folder = '/home/anuj/code/data/lfw_train'

# Build one dataframe from the folder, then split it into train / validation parts.
df = data_prep.get_df_from_folder(train_folder)
df_train, df_val = data_prep.split_train_val(df)

In [None]:
# Per column, count how many labels occur in more than one row of each split.
# The pandas-native `.sum()` is used on purpose: `np` is never imported in this
# notebook (NameError on a fresh kernel), and the column-wise sum here matches
# what `np.sum` returned on the boolean frame.
multi_row_labels_train = (df_train.groupby('label').count() > 1).sum()
multi_row_labels_val = (df_val.groupby('label').count() > 1).sum()
multi_row_labels_train, multi_row_labels_val

### Get the dataset and dataloader

In [None]:
%%time
dataset_train, dataloader_train = dataset.get_dataloader(df_train, image_side=160, batch_size=4*24, num_workers=8)
dataset_val, dataloader_val = dataset.get_dataloader(df_val, image_side=160, batch_size=4*24, num_workers=8)

### Visualize

In [None]:
# Show the first two validation batches, then stop iterating.
# NOTE(review): the second argument (5) is presumably the number of samples
# drawn per batch — confirm against src.visualize.
for ix, batch in enumerate(dataloader_val):
    if ix < 2:
        visualize.visualize(batch, 5)
    else:
        break

### Set up model, optimizer, loss function

In [None]:
# Index of the CUDA device that the model and loss are moved to, and that is
# passed to the training loop as `device`.
device_id = 2

In [None]:
# Model, Optimizer, Loss

# GPUs used for data-parallel training. torch.nn.DataParallel requires the
# wrapped module's parameters to live on device_ids[0], so that entry has to
# agree with `device_id`; fail early with a clear message if they drift apart.
device_ids = [2, 3]
assert device_ids[0] == device_id, (
    f"device_id={device_id} must equal device_ids[0]={device_ids[0]} for DataParallel"
)

# NOTE(review): 160 presumably matches the image_side passed to the data
# loaders — confirm against src.models.segnet.
model = segnet.SiameseNetworkLarge(160)
model = torch.nn.DataParallel(model, device_ids=device_ids).cuda(device_id)

# Adam with its default hyper-parameters over every model parameter.
optimizer = torch.optim.Adam(model.parameters())
loss_func = ContrastiveLoss().cuda(device_id)

### Set up logging

In [None]:
# Name of this experiment run; also used as the output folder name.
model_str = 'face-siamese-contrastive-3.04'
weights_folder = f"/home/anuj/weights/{model_str}"

# The TensorBoard writer is pointed at the same folder that is handed to the
# training loop as `weights_folder`.
writer = SummaryWriter(weights_folder)

log_message = 'logging to: {}'.format(weights_folder)
print(log_message)

### Train

In [None]:
# Gather everything the training loop needs, then launch it.
training_args = dict(
    model=model,
    dataloader_train=dataloader_train,
    dataloader_val=dataloader_val,
    loss_func=loss_func,
    optimizer=optimizer,
    writer=writer,
    device=device_id,
    weights_folder=weights_folder,
)
train.run_training_loop(**training_args)