In [None]:
%load_ext autoreload
%autoreload 2

# Run our best pretrained model

#### - Our best model was trained with `batch_size = 2048`, `num_epochs = 50`, initial `lr = 1e-3`

#### - Logs and visualizations (including the embedding visualization for the TensorBoard projector) of our best pretrained model are stored in `./runs/best/` and can be viewed with TensorBoard (e.g. `tensorboard --logdir runs/best`).

In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from IPython.display import display
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from torch.utils.tensorboard import SummaryWriter

import data_generator
import net
import viz

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Data loaders built with batch_size=2048, the setting the best model was trained with.
dg = data_generator.DataGenerator(root='./dataset', batch_size=2048)

# Restore the best checkpoint and compute descriptors for the test loader.
ckp_path = os.path.join('models', 'best.pt')
model = net.Net()
model, _ = net.load_ckp(ckp_path, model)
test_descriptors = net.compute_descriptor(model, dg.test_loader)

hist_fig, pca_fig, tsne_fig, confusion_fig = viz.get_all_plots(model, dg, test_descriptors)

# TensorBoard tag -> figure; insertion order fixes the order figures are shown and logged.
figures = {
    "Histogram": hist_fig,
    "PCA": pca_fig,
    "t-SNE": tsne_fig,
    "Confusion Heatmap": confusion_fig,
}

# `Figure.show()` needs a GUI backend; under the notebook's inline backend it only
# emits a warning. `add_figure` below also closes each figure by default, so nothing
# would be rendered. Display through IPython's rich display instead.
for fig in figures.values():
    display(fig)

# Log to ./runs/best so the logs land where the notes above say they do; a bare
# SummaryWriter() writes to runs/<datetime>_<hostname>. The `with` block guarantees
# the writer is flushed and closed even if logging raises.
with SummaryWriter(log_dir=os.path.join('runs', 'best')) as writer:
    for tag, fig in figures.items():
        writer.add_figure(tag, fig)
    writer.add_embedding(test_descriptors, metadata=dg.test_labels, tag='test_descriptors')

In [2]:
import data_generator
import train
import test

# Batch size for training from scratch; `batch_size` is also passed to train.run below.
batch_size = 512

# Loader settings collected in one place, then unpacked into the generator.
loader_kwargs = dict(root='./dataset', batch_size=batch_size)
dg = data_generator.DataGenerator(**loader_kwargs)

# Train models from scratch

#### Note: after each epoch, a checkpoint of the model is saved to `./models`

In [None]:
# Hyper-parameters for this from-scratch run. For comparison, the best pretrained
# model was trained with batch_size=2048, 50 epochs and an initial lr of 1e-3.
num_epochs = 10
learning_rate = 8e-4

train.run(
    dg,
    batch_size=batch_size,
    num_epochs=num_epochs,
    lr=learning_rate,
)

### After training finishes, you can inspect the logs and visualizations created during training in TensorBoard. If you want TensorBoard visualizations for each epoch's checkpoint, run the following code cell.

# Create histograms and embeddings for all checkpoints in `./models` to visualize in TensorBoard

In [None]:
# Create TensorBoard histograms and embeddings for every checkpoint found in
# ./models, using the data loaders in `dg`. When run in order, `dg` is the
# generator built with batch_size=512 in the training setup cell.
test.run(dg)

Start creating logs for all checkpoints
best.pt
checkpoint1.pt
