In [None]:
# Reload edited project modules (models, loss, train, ...) automatically
# before each cell executes, so changes to the .py files are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2


In [None]:
import torch
from torchvision.datasets import MNIST, FashionMNIST
from torchvision import transforms
from torch.optim import lr_scheduler
import torch.optim as optim
from torch.autograd import Variable

In [None]:
# NOTE(review): 0.1307 / 0.3081 look like the commonly quoted MNIST pixel
# statistics; FashionMNIST's are roughly 0.2860 / 0.3530. Kept unchanged so
# previously trained checkpoints stay comparable -- confirm this is intended.
mean, std = 0.1307, 0.3081

# Both splits use the same root directory so they share one on-disk copy of
# the dataset instead of downloading it into two differently named folders.
fmnist_root = '../data/FashionMNIST'

# One preprocessing pipeline for both splits: convert to a tensor, then
# normalise the single channel with the constants above.
fmnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((mean,), (std,)),
])

train_dataset = FashionMNIST(fmnist_root, train=True, download=True,
                             transform=fmnist_transform)
test_dataset = FashionMNIST(fmnist_root, train=False, download=True,
                            transform=fmnist_transform)
n_classes = 10

In [None]:
batch_size = 128
cuda = torch.cuda.is_available()

# Extra loader workers and pinned host memory are enabled only when CUDA is available.
kwargs = dict(num_workers=1, pin_memory=True) if cuda else {}

# Shuffle the training split each epoch; keep the test split in a fixed order.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    **kwargs,
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    **kwargs,
)

In [None]:
# Render matplotlib figures (used by the feature plots below) inline in the notebook.
%matplotlib inline
# Project-local dataset wrappers that turn a labelled dataset into image triplets.
from genereateTriplets import TripletMNIST, CombinedTriplet


In [None]:
# Wrap each split so that every sample is a triplet of images
# (presumably anchor / positive / negative -- confirm in genereateTriplets).
triplet_train_dataset, triplet_test_dataset = (
    TripletMNIST(train_dataset),
    TripletMNIST(test_dataset),
)


In [None]:
# Same loader options as the plain loaders: workers + pinned memory only with CUDA.
kwargs = dict(num_workers=1, pin_memory=True) if cuda else {}

# Batches of triplets for training (shuffled) and evaluation (fixed order).
triplet_train_loader = torch.utils.data.DataLoader(
    triplet_train_dataset,
    batch_size=batch_size,
    shuffle=True,
    **kwargs,
)
triplet_test_loader = torch.utils.data.DataLoader(
    triplet_test_dataset,
    batch_size=batch_size,
    shuffle=False,
    **kwargs,
)

In [None]:
from models import EncoderNet, TripletNet
from loss import TripletLoss

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The triplet network wraps the encoder; the encoder object is kept in its own
# variable so its weights can be saved separately later.
encoder = EncoderNet()
model = TripletNet(encoder)
model.to(device)

In [None]:
# Triplet loss with a margin of 1.
loss_fn = TripletLoss(margin=1)

# Optimisation hyper-parameters.
lr = 1e-3
n_epochs = 20
log_interval = 100  # presumably batches between progress logs inside `fit` -- confirm

optimizer = optim.Adam(model.parameters(), lr=lr)
# Multiply the learning rate by 0.1 every 8 scheduler steps
# (StepLR's default last_epoch=-1 starts a fresh schedule).
scheduler = lr_scheduler.StepLR(optimizer, step_size=8, gamma=0.1)

In [None]:
from train import fit
# Train on the triplet loaders and evaluate on the triplet test loader.
# `cuda` is the availability flag computed earlier; presumably `fit` uses it to
# move batches to the GPU -- confirm in train.py.
fit(triplet_train_loader, triplet_test_loader, model, loss_fn, optimizer, scheduler, n_epochs, cuda, log_interval)

In [None]:
from plotExtract import extract_features, plot_features

# Embed the training split with the trained model and plot it.
train_features, train_labels = extract_features(train_loader, model)
plot_features(train_features, train_labels)

# Do the same for the held-out test split.
val_features, val_labels = extract_features(test_loader, model)
plot_features(val_features, val_labels)

In [None]:
import time

# Local-time stamp such as '20240131-235959', used to make checkpoint names unique.
timestr = time.strftime("%Y%m%d-%H%M%S", time.localtime())

In [None]:
# Timestamped checkpoint paths so successive runs do not overwrite each other.
save_dir = './saved_models'
enc_weights_file = f'{save_dir}/fmnist_enc_save{timestr}'
trip_weights_file = f'{save_dir}/fmnist_trip_save{timestr}'

In [None]:
import os

# torch.save does not create missing parent directories, so make sure the
# checkpoint folder exists before writing (a no-op when it already does).
for weights_path in (enc_weights_file, trip_weights_file):
    os.makedirs(os.path.dirname(weights_path) or '.', exist_ok=True)

# Save the encoder on its own (reloadable without the TripletNet wrapper)
# and the full triplet model alongside it.
torch.save(encoder.state_dict(), enc_weights_file)
torch.save(model.state_dict(), trip_weights_file)