In [23]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import networkx as nx
from torch.utils.data import Dataset, DataLoader
import torch
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F


class Generator(nn.Module):
    def __init__(self, noise_dim, src_dest_dim, hidden_size, output_size, num_layers=1):
        super(Generator, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.embed = nn.Embedding(7, src_dest_dim)  # Assuming 16 unique source/destination values
        self.noise_transform = nn.Linear(src_dest_dim, src_dest_dim)
        self.lstm = nn.LSTM(src_dest_dim, hidden_size, num_layers, batch_first=True)  # *3 for noise, src, dest
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, noise, src, dest):
        src_embedded = self.embed(src).unsqueeze(1)
        dest_embedded =  self.embed(dest).unsqueeze(1)
        noise_transformed = self.noise_transform(noise)
        combined_input = torch.cat([noise_transformed, src_embedded, dest_embedded], dim=1)
        h0 = torch.zeros(self.num_layers, combined_input.size(0), self.hidden_size).to(combined_input.device)
        c0 = torch.zeros(self.num_layers, combined_input.size(0), self.hidden_size).to(combined_input.device)
        out, _ = self.lstm(combined_input, (h0, c0))
        out = self.fc(out[:, -1, :])
        return out


class Discriminator(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers=1):
        super(Discriminator, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(1, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, path_sequence):
        batch_size = path_sequence.size(0)

        if path_sequence.dim() == 2:
            path_sequence = path_sequence.unsqueeze(-1)
        h0 = torch.zeros(self.num_layers, batch_size, self.hidden_size).to(path_sequence.device)
        c0 = torch.zeros(self.num_layers, batch_size, self.hidden_size).to(path_sequence.device)
        out, (hn, cn) = self.lstm(path_sequence, (h0, c0))
        out = self.fc(out[:, -1, :]) 
        out = self.sigmoid(out)
        
        return out

    

def create_4x4_mesh():
    G = nx.grid_2d_graph(4, 4)
    return G

G = create_4x4_mesh()

def encode_path(path, grid_size=(4, 4)):
    width, height = grid_size
    return [x + y * width for x, y in path]

def decode_path(encoded_path, grid_size=(4, 4)):
    width, height = grid_size
    return [(index % width, index // width) for index in encoded_path]

path = [(0, 0), (1, 0), (1, 1), (1, 2)]  
encoded_path = encode_path(path)
decoded_path = decode_path(encoded_path)

print(f"Path: {path}")
print(f"encoded path:{encoded_path}")
print(f"decoded Path: {decoded_path}")

G = nx.grid_2d_graph(4, 4)

node_list = list(G.nodes())
node_index = {node: idx for idx, node in enumerate(node_list)}


training_data = []
for source in node_list:
    for target in node_list:
        if source != target:
            shortest_path = nx.shortest_path(G, source, target)
            encoded_path = [node_index[node] for node in shortest_path]
            training_data.append({
                'source': node_index[source],
                'target': node_index[target],
                'path': encoded_path
            })
print(training_data)


Path: [(0, 0), (1, 0), (1, 1), (1, 2)]
encoded path:[0, 1, 5, 9]
decoded Path: [(0, 0), (1, 0), (1, 1), (1, 2)]
[{'source': 0, 'target': 1, 'path': [0, 1]}, {'source': 0, 'target': 2, 'path': [0, 1, 2]}, {'source': 0, 'target': 3, 'path': [0, 1, 2, 3]}, {'source': 0, 'target': 4, 'path': [0, 4]}, {'source': 0, 'target': 5, 'path': [0, 1, 5]}, {'source': 0, 'target': 6, 'path': [0, 4, 5, 6]}, {'source': 0, 'target': 7, 'path': [0, 4, 5, 6, 7]}, {'source': 0, 'target': 8, 'path': [0, 4, 8]}, {'source': 0, 'target': 9, 'path': [0, 4, 8, 9]}, {'source': 0, 'target': 10, 'path': [0, 4, 8, 9, 10]}, {'source': 0, 'target': 11, 'path': [0, 1, 2, 3, 7, 11]}, {'source': 0, 'target': 12, 'path': [0, 4, 8, 12]}, {'source': 0, 'target': 13, 'path': [0, 4, 8, 12, 13]}, {'source': 0, 'target': 14, 'path': [0, 4, 5, 6, 10, 14]}, {'source': 0, 'target': 15, 'path': [0, 1, 2, 3, 7, 11, 15]}, {'source': 1, 'target': 0, 'path': [1, 0]}, {'source': 1, 'target': 2, 'path': [1, 2]}, {'source': 1, 'target': 3

In [28]:
noise_dim = 7
src_dest_dim = 7
hidden_size = 7
output_size = 15 * 16

generator = Generator(noise_dim=noise_dim, src_dest_dim=src_dest_dim, hidden_size=hidden_size, output_size=output_size)
discriminator = Discriminator(input_size=16, hidden_size=hidden_size)  # Assuming each step in the path is one-hot encoded

adversarial_loss = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=0.001)
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.001)

In [29]:
class PathsDataset(Dataset):
    def __init__(self, data):
        self.data = data
        self.max_path_length = max(len(item['path']) for item in data)  # Find the max path length
    def __getitem__(self, idx):
        item = self.data[idx]
        src = torch.tensor(item['source'], dtype=torch.long)
        dest = torch.tensor(item['target'], dtype=torch.long)
        path = torch.tensor(item['path'], dtype=torch.float)
        # Pad the path with 0s to reach the max_path_length
        path_padded = F.pad(path, (0, self.max_path_length - len(path)), "constant", 0)
        return src, dest, path_padded

    def __len__(self):
        return len(self.data)
    
def collate_fn(batch):
    srcs, dests, paths = zip(*batch)
    srcs = torch.tensor(srcs, dtype=torch.long)
    dests = torch.tensor(dests, dtype=torch.long)
    paths_padded = pad_sequence(paths, batch_first=True, padding_value=0)
    return srcs, dests, paths_padded


In [26]:
dataset = PathsDataset(training_data)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

In [30]:
noise_dim = 7
latent_dim = noise_dim  
epochs = 100 
for epoch in range(epochs):
    for i, (src, dest, real_paths) in enumerate(dataloader):
        
        batch_size = real_paths.size(0)
        real_labels = torch.ones(batch_size, 1)
        fake_labels = torch.zeros(batch_size, 1)
        optimizer_D.zero_grad()

        print(f"Real paths shape: {real_paths.shape}")  
        real_paths = real_paths  

        real_loss = adversarial_loss(discriminator(real_paths), real_labels)

        noise = torch.randn(batch_size, 1, latent_dim)
        fake_paths = generator(noise, src, dest).detach()  
        fake_loss = adversarial_loss(discriminator(fake_paths), fake_labels)

        d_loss = real_loss + fake_loss
        d_loss.backward()
        optimizer_D.step()

        optimizer_G.zero_grad()


        noise = torch.randn(batch_size, 1, latent_dim)
        fake_paths = generator(noise, src, dest)


        g_loss = adversarial_loss(discriminator(fake_paths), real_labels)
        g_loss.backward()
        optimizer_G.step()
        

        if (i + 1) % 50 == 0:  # Assuming you want to print every 50 batches
            print(f"Epoch [{epoch+1}/{epochs}], Batch [{i+1}], D Loss: {d_loss.item()}, G Loss: {g_loss.item()}")

Real paths shape: torch.Size([32, 7])


IndexError: index out of range in self