In [2]:
# Suppress warnings: swap warnings.warn for a function that ignores every call
import warnings


def warn(*args, **kwargs):
    """Accept any arguments, do nothing, and return None."""
    return None


warnings.warn = warn

# Python
import matplotlib.pyplot as plt

# NumPy and PyTorch
import torch
import numpy as np
import torch.optim as optim
from torch.utils.data import DataLoader

# Custom
from path_loader import PathDataLoader
from networks import SiameseNetworkSimple, SiameseNetworkComplex
from losses import ContrastiveLoss
from patch_generator import PatchGenerator, normalize_patches

# Set random seeds for both RNGs the notebook draws from:
# NumPy drives the pair sampling in generate_pairs; PyTorch drives the
# shuffling done by DataLoader(shuffle=True) (and presumably the network's
# weight initialisation — confirm in networks.py).
np.random.seed(0)
torch.manual_seed(0)

In [None]:
def train_siamese_network(train_loader, net, criterion, optimizer, epochs=10,
                          use_labels=False, log_every=1000):
    """Train a Siamese network on batches of paired inputs.

    Every batch from ``train_loader`` is unpacked as ``(input1, input2, label)``.
    Both inputs are passed through the same ``net`` and the two outputs are
    handed to ``criterion``; the resulting loss is back-propagated and one
    optimizer step is taken per batch.

    Args:
        train_loader: Iterable yielding ``(input1, input2, label)`` batches.
        net: Callable applied separately to each input of the pair.
        criterion: Callable returning a loss that supports ``backward()`` and
            ``item()``. Called as ``criterion(output1, output2)`` by default,
            or as ``criterion(output1, output2, label)`` when ``use_labels``
            is true.
        optimizer: Object providing ``zero_grad()`` and ``step()``.
        epochs: Number of full passes over ``train_loader``.
        use_labels: Forward the batch labels to ``criterion`` as a third
            argument. Defaults to ``False``, i.e. the labels are ignored.
        log_every: Print the current loss whenever the iteration index within
            an epoch is a multiple of this value. A falsy value (``0`` or
            ``None``) disables the printing.

    Returns:
        None. ``net`` is updated in place through ``optimizer``.
    """
    for epoch in range(epochs):
        for i, (input1, input2, label) in enumerate(train_loader):
            output1, output2 = net(input1), net(input2)

            if use_labels:
                loss = criterion(output1, output2, label)
            else:
                loss = criterion(output1, output2)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if log_every and i % log_every == 0:
                print(f"Epoch {epoch}, Iteration {i}, Loss {loss.item()}")

In [None]:
# Load data: read the path file and keep only its first MAX_PATHS entries
DATA_FILE = 'eu_city_2x2_macro_306.bin'
MAX_PATHS = 10000

pathLoader = PathDataLoader()
paths = pathLoader.read(DATA_FILE)[:MAX_PATHS]

In [None]:
# Format data
# Number of pairs per mini-batch; passed to both DataLoaders in generate_dataloaders.
batch_size = 10
# Fraction of the generated pairs used for training; the remaining pairs
# form the validation set.
train_val_ratio = 0.95

In [None]:
def generate_pairs(patches, ratio_local, num_pairs = 10000):
    """Sample random pairs of paths, labelled by whether they share a patch.

    A pair is *local* (label 0) when both paths come from the same patch and
    *non-local* (label 1) otherwise. For each pair two patch indices are drawn
    uniformly; with probability ``p`` the second index is then forced to equal
    the first. Two independent draws already coincide with probability
    ``1 / num_patches``, so ``p`` is chosen to make the overall share of local
    pairs equal ``ratio_local``:

        (1 - p) * (1 - 1 / num_patches) = 1 - ratio_local

    Within a local pair the two paths are always distinct elements of the
    patch. A local draw that lands on a patch holding exactly one path cannot
    satisfy that and is discarded and redrawn (this slightly lowers the share
    of local pairs when such patches exist).

    Args:
        patches: Sequence of patches, each a sequence of paths that
            ``torch.tensor(path, dtype=torch.float)`` accepts.
        ratio_local: Target fraction of local pairs. Values below
            ``1 / num_patches`` cannot be reached because chance collisions
            alone already produce that share.
        num_pairs: Number of pairs to generate.

    Returns:
        A list of ``(path_a, path_b, label)`` tuples: two float tensors and a
        scalar long tensor that is 0 for a local pair and 1 otherwise.

    Raises:
        ValueError: If ``patches`` is empty, or if only local pairs can be
            drawn while every patch holds exactly one path (no valid pair
            exists, so sampling would never finish).
    """
    num_patches = len(patches)
    if num_patches == 0:
        raise ValueError("patches must contain at least one patch")

    if num_patches == 1:
        # With a single patch every pair is local; the general formula would
        # divide by zero here.
        ratio_local_pair = 1.0
    else:
        ratio_nonlocal_pair = (1 - ratio_local) / (1 - (1 / num_patches))
        ratio_local_pair = 1 - ratio_nonlocal_pair

    # Guard against a sampling loop that can never produce a pair.
    only_singletons = all(len(patch) == 1 for patch in patches)
    if num_pairs > 0 and only_singletons and ratio_local_pair >= 1:
        raise ValueError(
            "cannot build local pairs: every patch contains exactly one path"
        )

    data_pairs = []
    while len(data_pairs) < num_pairs:

        # Pick local pair or non-local pair
        first_patch_index = np.random.randint(num_patches)
        second_patch_index = np.random.randint(num_patches)
        rnd_local = np.random.uniform(0, 1)

        if rnd_local < ratio_local_pair:
            second_patch_index = first_patch_index

        is_local = first_patch_index == second_patch_index
        first_patch = patches[first_patch_index]
        second_patch = patches[second_patch_index]

        # A one-path patch has no second, distinct path to pair with: redraw
        # instead of spinning forever in the de-duplication loop below.
        if is_local and len(first_patch) == 1:
            continue

        # Pick a random path within each chosen patch
        rnd_1 = np.random.randint(len(first_patch))
        rnd_2 = np.random.randint(len(second_patch))

        # In case we get the same path twice
        if is_local:
            while rnd_1 == rnd_2:
                rnd_2 = np.random.randint(len(second_patch))

        label = 0 if is_local else 1

        data_pairs.append((
            torch.tensor(first_patch[rnd_1], dtype=torch.float),
            torch.tensor(second_patch[rnd_2], dtype=torch.float),
            torch.tensor(label, dtype=torch.long),
        ))

    return data_pairs


In [None]:
def generate_dataloaders(data_pairs, train_val_ratio, batch_size):
    """Split ``data_pairs`` into train/validation parts and wrap each in a DataLoader.

    The first ``int(train_val_ratio * len(data_pairs))`` items become the
    training set and the remaining items the validation set. Both loaders
    use ``batch_size`` and shuffle their data.

    Returns:
        A ``(train_loader, val_loader)`` tuple.
    """
    split_at = int(train_val_ratio * len(data_pairs))
    train_part, val_part = data_pairs[:split_at], data_pairs[split_at:]

    train_loader = DataLoader(train_part, batch_size, shuffle=True)
    val_loader = DataLoader(val_part, batch_size, shuffle=True)
    return train_loader, val_loader

In [None]:
# Generate patches
gen = PatchGenerator(num_patches=8, attribute="transmitter")
patches_raw = gen.generate_patches(paths)

# Transform PathPropagation objects to feature vectors, then normalize them
patches_features = gen.transform_patches(patches_raw)
patches = normalize_patches(patches_features)

# Generate pairs
RATIO_LOCAL = 0.5
NUM_PAIRS = 100000
patches_pairs = generate_pairs(patches, RATIO_LOCAL, NUM_PAIRS)

# Create dataloaders
dataloader_train, dataloader_val = generate_dataloaders(
    patches_pairs,
    train_val_ratio,
    batch_size,
)

In [None]:
# Instantiate the Siamese Network and Loss Function
net = SiameseNetworkSimple()
# NOTE(review): ContrastiveLossSimple is not among the names imported in the
# visible import cell (only ContrastiveLoss is imported from losses). Unless it
# is defined elsewhere, this line raises NameError on a fresh kernel — confirm
# which loss class is intended and import it.
criterion = ContrastiveLossSimple()
# Adam over all parameters of the network, learning rate 0.0005.
optimizer = optim.Adam(net.parameters(), lr=0.0005)

In [None]:
# Train the network on the training loader for 20 epochs; the loss is printed
# periodically by train_siamese_network.
train_siamese_network(dataloader_train, net, criterion, optimizer, epochs=20)

In [None]:
# Step 1 to visualize embeddings
# Pull a single batch from the validation loader; the following cells inspect it.
# The loader was created with shuffle=True, so re-running this cell may yield a
# different batch.
dataloader_object = iter(dataloader_val)
data_batch = next(dataloader_object)

In [None]:
# Step 2 to visualize embeddings
sample_index = 5  # Change this index to inspect a different pair of the batch

print(data_batch[2]) # See labels
print(data_batch[0][sample_index])
print(data_batch[1][sample_index])

# Inspection only: no gradients are needed, so skip building the autograd graph.
# NOTE(review): the network is left in whatever train/eval mode it is already in.
with torch.no_grad():
    embeddings = net(data_batch[0])
    embeddings2 = net(data_batch[1])

print(embeddings[sample_index])
print(embeddings2[sample_index])

In [None]:
# Step 3 to visualize embeddings

def _plot_embedding_rows(rows, component_indices):
    """Draw one line per embedding (value against component index) on a new figure and show it."""
    plt.figure(figsize=(8, 6))
    for row in rows:
        plt.plot(component_indices, row, marker='o', linestyle='-')
    plt.title('Visualization of 1D Embeddings')
    plt.xlabel('Component')
    plt.ylabel('Embedding Value')
    plt.show()


# Convert both embedding batches to (nested) Python lists
embeddings_list = embeddings.squeeze().tolist()
embeddings_list2 = embeddings2.squeeze().tolist()

# x-axis: one index per component of the first embedding
indices = list(range(len(embeddings_list[0])))

# One figure for each side of the pairs
for rows in (embeddings_list, embeddings_list2):
    _plot_embedding_rows(rows, indices)