In [None]:
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.models import resnet50  # Pretrained network
import torchvision.transforms as T
from model import SiameseNetwork
from loss import ContrastiveLoss
from data import SiameseDataset, Grayscale
from utils import train_fn
from pathlib import Path
import matplotlib.pyplot as plt

plt.style.use("ggplot")
% matplotlib inline

In [None]:
# Pick the compute device once: the GPU when CUDA is usable, otherwise the CPU.
# Later cells move the network and the loss onto this device.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Running on device: {device}")
# Preprocessing pipeline handed to the dataset as its `transform`.
transform = T.Compose(
    [
        T.Resize(257),      # scale so the shorter image side is 257 px
        T.CenterCrop(256),  # keep the central 256 x 256 region
        T.ToTensor(),       # PIL image -> float tensor
        Grayscale(),        # project-local transform from `data`; presumably collapses colour channels - confirm there
    ]
)

In [None]:
# Expected dataset layout (the dataset cell reads `root / "sketches"` and
# `root / "photos"`):
# root
# | -> sketches
# | -> photos
root = Path()  # empty path, i.e. the current directory; point this at the dataset folder


In [None]:
# Hyper-parameters read by the cells below.

# Model
emb_dim = 1024  # passed to SiameseNetwork as `emb_dim`
freeze = True   # passed to SiameseNetwork as `freeze`; presumably freezes the pretrained encoder - confirm in model.py

# Optimisation
bs = 16         # batch size given to the DataLoader
n_epochs = 10   # number of iterations of the training loop
lr = 0.001      # learning rate given to Adam

# Loss
alpha = 0.25    # passed to ContrastiveLoss as `alpha`; exact meaning is defined in loss.py

In [None]:
# Dataset built from the two image folders under `root`, wrapped in a
# shuffling loader that yields batches of `bs` items.
siamese_ds = SiameseDataset(
    root / "sketches",
    root / "photos",
    transform=transform,
)
siamese_dl = DataLoader(siamese_ds, batch_size=bs, shuffle=True)

# ResNet-50 backbone; the positional `True` is torchvision's legacy
# `pretrained` flag, i.e. pretrained weights are loaded.
encoder_network = resnet50(True)

# Siamese model built around the backbone, moved to the selected device.
# `rate=0.6` is presumably a dropout rate - confirm in model.py.
network = SiameseNetwork(
    encoder_network=encoder_network,
    emb_dim=emb_dim,
    rate=0.6,
    freeze=freeze,
).to(device)

# Adam over the network's parameters, plus the contrastive loss used for training.
optimizer = optim.Adam(network.parameters(), lr=lr)
loss_fn = ContrastiveLoss(alpha=alpha, device=device)

In [None]:
# Train for `n_epochs` iterations and keep the value returned by `train_fn`
# for each one in `losses`.
# NOTE: this loop is live code - running the cell starts training immediately.
losses = []
for _ in range(n_epochs):
    # presumably one full pass over `siamese_dl` returning that epoch's loss - confirm in utils.train_fn
    loss = train_fn(network, loss_fn, optimizer, siamese_dl, device)
    losses.append(loss)

In [None]:
# Loss vs. epochs: one point per entry in `losses`, numbered from 1, with a
# title and axis labels so the figure can be read on its own.
fig, ax = plt.subplots()
ax.plot(range(1, len(losses) + 1), losses)
ax.set(title="Loss vs. Epochs", xlabel="Epoch", ylabel="Loss")
plt.show()  # display the figure without echoing the Line2D repr

# Checking network on the train set

In [None]:
# Collect every (sketches, photos) batch produced by the loader.
# Assumes each batch is a tensor with the batch on dim 0 (default DataLoader
# collation) - confirm against SiameseDataset.
# Note: this holds the whole training set in memory at once.
train_sketches, train_photos = [], []
for sketches, photos in siamese_dl:
    train_sketches.append(sketches)
    train_photos.append(photos)

# Join the batches along their existing batch dimension so that indexing dim 0
# selects a single sample. `torch.stack` would add a new leading
# "batch index" dimension and raise whenever the final batch is smaller than
# `bs`; `torch.cat` accepts a shorter last batch.
train_sketches = torch.cat(train_sketches, dim=0)
train_photos = torch.cat(train_photos, dim=0)

In [None]:
# image_embeddings = network.encode_samples(train_photos)
# image_embeddings = image_embeddings / torch.norm(image_embeddings)
# index = 0
#
# sketch = train_sketches[index]
# if sketch.ndim == 3:
#     sketch = sketch.unsqueeze(0)
#
# sketch_embedding = network.encode_samples(sketch)
# sketch_embedding = sketch_embedding / torch.norm(sketch_embedding)