In [None]:
# Suppress warnings for the whole session.
# An "ignore" filter also silences warnings raised from C extensions and from
# modules that imported `warn` directly, which rebinding `warnings.warn` to a
# no-op would miss.
import warnings


def suppress_warnings():
    """Install a catch-all "ignore" warnings filter so no warning is shown."""
    warnings.filterwarnings("ignore")


suppress_warnings()

# Python
import matplotlib.pyplot as plt

# NumPy and PyTorch
import torch
import numpy as np
import torch.optim as optim
from torch.utils.data import DataLoader

# Custom
from path_loader import PathDataLoader
from networks import SiameseNetworkSimple
from losses import ContrastiveLossSimple
from patch_generator import PatchGenerator

# Set random seeds. NumPy drives the pair sampling in `generate_pairs`; torch
# drives DataLoader shuffling and (presumably) the network's weight init.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
def train_siamese_network(train_loader, net, criterion, optimizer, epochs=10, log_every=1000):
    """Train a Siamese network on batches of ``(input1, input2, label)``.

    Both inputs of a pair are embedded by the same ``net`` (shared weights) and
    the two embeddings are scored by ``criterion(output1, output2, label)``.

    Args:
        train_loader: Iterable yielding ``(input1, input2, label)`` batches.
        net: Network that embeds each input; put into training mode here.
        criterion: Loss callable taking ``(output1, output2, label)``.
        optimizer: Optimizer over ``net``'s parameters.
        epochs: Number of full passes over ``train_loader``.
        log_every: Print the current batch loss every ``log_every`` iterations
            (iteration 0 always included); ``0`` or ``None`` disables it.

    Returns:
        None. Progress is printed, including the mean batch loss per epoch.
    """
    # The network may have been switched to eval mode by an earlier evaluation
    # cell; make sure layers such as dropout/batch-norm (if any) train properly.
    net.train()
    for epoch in range(epochs):
        epoch_loss = 0.0
        num_batches = 0
        for i, (input1, input2, label) in enumerate(train_loader):
            output1, output2 = net(input1), net(input2)
            loss = criterion(output1, output2, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            epoch_loss += batch_loss
            num_batches += 1
            if log_every and i % log_every == 0:
                print(f"Epoch {epoch}, Iteration {i}, Loss {batch_loss}")
        # A single batch loss is noisy; the epoch mean is a steadier signal.
        if num_batches:
            print(f"Epoch {epoch} done, mean loss {epoch_loss / num_batches}")

In [None]:
# Load propagation paths from disk, keeping only the first MAX_PATHS of them.
DATA_FILE = 'eu_city_2x2_macro_306.bin'
MAX_PATHS = 10000
pathLoader = PathDataLoader()
paths = pathLoader.read(DATA_FILE)[:MAX_PATHS]

In [None]:
# Data-loader settings: the fraction of generated pairs used for training
# (the remainder is held out for validation) and the number of pairs per batch.
train_val_ratio = 0.95
batch_size = 10

In [None]:
def generate_pairs(patches, ratio_local, num_pairs=10000):
    """Sample labelled pairs of paths for contrastive training.

    Each pair is ``(path_a, path_b, label)`` of tensors. ``label`` is 0 when
    both paths come from the same patch (a "local" pair) and 1 otherwise. A
    local pair always uses two different path indices within its patch.

    Two patch indices are drawn uniformly and independently, then with
    probability ``ratio_local_pair`` the second is overwritten by the first.
    Since independent draws already coincide with probability
    ``1 / num_patches``, ``ratio_local_pair`` is chosen so that the expected
    fraction of local pairs equals ``ratio_local``. This holds for
    ``ratio_local >= 1 / num_patches``; below that it is ``1 / num_patches``.

    Args:
        patches: Sequence of patches; each patch is a sequence of per-path
            feature vectors (anything ``torch.tensor`` accepts).
        ratio_local: Target fraction of local (label 0) pairs.
        num_pairs: Number of pairs to generate.

    Returns:
        List of ``(float tensor, float tensor, long scalar tensor)`` tuples.

    Raises:
        ValueError: If ``patches`` is empty or any patch has fewer than two
            paths (a local pair from it could never use two distinct paths).
    """
    num_patches = len(patches)
    if num_patches == 0:
        raise ValueError("generate_pairs needs at least one patch")
    too_small = [idx for idx, patch in enumerate(patches) if len(patch) < 2]
    if too_small:
        raise ValueError(
            f"patches {too_small} contain fewer than two paths; a local pair "
            "needs two distinct paths from the same patch")

    if num_patches == 1:
        # Every draw is necessarily local; the formula below would divide by zero.
        ratio_local_pair = 1.0
    else:
        ratio_nonlocal_pair = (1 - ratio_local) / (1 - (1 / num_patches))
        ratio_local_pair = 1 - ratio_nonlocal_pair

    data_pairs = []
    while len(data_pairs) < num_pairs:
        # Keep this draw order fixed so results stay reproducible under np.random.seed.
        first_patch_index = np.random.randint(num_patches)
        second_patch_index = np.random.randint(num_patches)
        if np.random.uniform(0, 1) < ratio_local_pair:
            second_patch_index = first_patch_index
        is_local = first_patch_index == second_patch_index

        # Pick a random path within each chosen patch.
        first_patch = patches[first_patch_index]
        second_patch = patches[second_patch_index]
        rnd_1 = np.random.randint(len(first_patch))
        rnd_2 = np.random.randint(len(second_patch))

        # A local pair must use two different paths; this terminates because
        # every patch was checked above to hold at least two paths.
        while is_local and rnd_1 == rnd_2:
            rnd_2 = np.random.randint(len(second_patch))

        label = 0 if is_local else 1
        data_pairs.append((
            torch.tensor(first_patch[rnd_1], dtype=torch.float),
            torch.tensor(second_patch[rnd_2], dtype=torch.float),
            torch.tensor(label, dtype=torch.long),
        ))

    return data_pairs


In [None]:
def generate_dataloaders(data_pairs, train_val_ratio, batch_size, shuffle_val=True):
    """Split ``data_pairs`` into training and validation ``DataLoader``s.

    The first ``train_val_ratio`` fraction of ``data_pairs`` (in the given
    order) becomes the training set; the remainder is the validation set.

    Args:
        data_pairs: Sequence of samples (e.g. the tuples from ``generate_pairs``).
        train_val_ratio: Fraction in ``[0, 1]`` of samples used for training.
        batch_size: Samples per batch for both loaders.
        shuffle_val: Whether the validation loader shuffles; pass ``False`` for
            a fixed, reproducible validation order.

    Returns:
        ``(train_loader, val_loader)``; the training loader always shuffles.

    Raises:
        ValueError: If ``train_val_ratio`` is outside ``[0, 1]``.
    """
    if not 0 <= train_val_ratio <= 1:
        raise ValueError(f"train_val_ratio must be in [0, 1], got {train_val_ratio}")
    # Round away float noise before truncating: e.g. 0.29 * 100 evaluates to
    # 28.999999999999996 and a bare int() would drop one training sample.
    train_size = int(round(train_val_ratio * len(data_pairs), 6))

    train_loader = DataLoader(data_pairs[:train_size], batch_size, shuffle=True)
    val_loader = DataLoader(data_pairs[train_size:], batch_size, shuffle=shuffle_val)

    return train_loader, val_loader

In [None]:
# Generate patches from the loaded paths (8 patches; grouping presumably by the
# `transmitter` attribute -- see PatchGenerator).
gen = PatchGenerator(num_patches=8, attribute="transmitter")
patches = gen.generate_patches(paths)

# Transform PathPropagation objects to per-path feature vectors
patches = gen.transform_patches(patches)

# Min-max scale every feature value into [-1, 1]. One global min/max is taken
# over all values of all paths in all patches (not per feature).
all_values = [value for patch in patches for path in patch for value in path]
if not all_values:
    raise ValueError("no feature values to normalize: every patch is empty")
data_min = min(all_values)
data_max = max(all_values)
data_range = data_max - data_min
if data_range == 0:
    raise ValueError(f"all feature values equal {data_min}; cannot min-max scale")

# Scale value-by-value while keeping the patch/path nesting, so every path keeps
# its own length (no fixed number of features per path is assumed).
patches = [
    [[2 * ((value - data_min) / data_range) - 1 for value in path] for path in patch]
    for patch in patches
]

# Generate pairs (target fraction of same-patch pairs: 0.5)
patches_pairs = generate_pairs(patches, 0.5)

# Create dataloaders
dataloader_train, dataloader_val = generate_dataloaders(patches_pairs, train_val_ratio, batch_size)

In [None]:
# Build the shared-weight Siamese network, its contrastive loss, and an Adam
# optimizer over the network's parameters.
net = SiameseNetworkSimple()
criterion = ContrastiveLossSimple()
optimizer = optim.Adam(params=net.parameters(), lr=5e-4)

In [None]:
# Train on the training pairs for 100 epochs.
train_siamese_network(
    train_loader=dataloader_train,
    net=net,
    criterion=criterion,
    optimizer=optimizer,
    epochs=100,
)

In [None]:
# Start an iterator over the validation set; later cells keep drawing from it.
data_iter = iter(dataloader_val)

# Retrieve the first batch and embed the first element of each of its pairs.
first_batch = next(data_iter)
first_data = first_batch[0]

# Inference: eval mode (affects dropout/batch-norm, if the net has any) and no
# autograd graph; then restore whatever mode the network was in.
# NOTE(review): `embeddings` is overwritten by the pair-inspection cell below.
was_training = net.training
net.eval()
with torch.no_grad():
    embeddings = net(first_data)
net.train(was_training)

In [None]:
# Advance through the validation iterator until a batch contains at least one
# same-patch (label 0) pair. Re-running this cell keeps advancing `data_iter`.
next_batch = None
for batch in data_iter:
    if (batch[2] == 0).any():
        next_batch = batch
        break
if next_batch is None:
    raise RuntimeError("validation loader exhausted without finding a label-0 pair")

In [None]:
# Inspect one same-patch (label 0) pair of `next_batch`: print its raw feature
# vectors and the network's embedding of each side.
labels = next_batch[2]
local_positions = (labels == 0).nonzero(as_tuple=True)[0]
if len(local_positions) == 0:
    raise ValueError("next_batch contains no label-0 (same-patch) pair to inspect")
# Use the first label-0 position rather than a fixed index, which could point at
# a non-local pair or past the end of a short batch.
pair_idx = int(local_positions[0])
print(labels)
print(next_batch[0][pair_idx])
print(next_batch[1][pair_idx])

# Embed both sides in eval mode without building an autograd graph, then
# restore whatever train/eval mode the network was in.
was_training = net.training
net.eval()
with torch.no_grad():
    embeddings = net(next_batch[0])
    embeddings2 = net(next_batch[1])
net.train(was_training)
print(embeddings[pair_idx])
print(embeddings2[pair_idx])

In [None]:
def plot_embedding_batch(embedding_batch, title):
    """Plot a batch of embeddings against each pair's position in the batch.

    Dim 0 of ``embedding_batch`` indexes the pairs of the batch; each row is
    flattened, so a single-component embedding draws one line and a
    multi-component embedding draws one line per component.
    """
    rows = embedding_batch.detach().reshape(embedding_batch.shape[0], -1).tolist()
    _, ax = plt.subplots(figsize=(8, 6))
    ax.plot(range(len(rows)), rows, marker='o', linestyle='-')
    ax.set(title=title, xlabel='Pair index in batch', ylabel='Embedding value')
    plt.show()


# Compare the embeddings of the two sides of every pair in the inspected batch.
plot_embedding_batch(embeddings, 'Embeddings of the first element of each pair')
plot_embedding_batch(embeddings2, 'Embeddings of the second element of each pair')