In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import sys
import pickle
import random

class DFNetwork(nn.Module):
    '''
    Three-layer MLP that maps a (noisy) state-action vector plus a diffusion
    timestep to a vector of the same width as the state-action vector.

    The timestep is turned into a learned embedding of size ``hidden_dim`` and
    concatenated to the input before the first linear layer.
    '''

    def __init__(self, state_action_vector_dim, hidden_dim, timesteps):
        super().__init__()
        # One learned vector per diffusion step index in [0, timesteps).
        self.time_embedding = nn.Embedding(timesteps, hidden_dim)

        # Width of the input once the timestep embedding has been appended.
        conditioned_dim = state_action_vector_dim + hidden_dim
        self.fc1 = nn.Linear(conditioned_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, state_action_vector_dim)

    def forward(self, x, t):
        '''
        Input : x batch of state-action vectors, t integer timestep index per row
        Output : tensor with one row per input row and state_action_vector_dim columns
        '''
        step_features = self.time_embedding(t)
        # Concatenate along the feature axis (dim=1), so x must be 2-D.
        hidden = torch.cat((x, step_features), dim=1)
        for layer in (self.fc1, self.fc2):
            hidden = torch.relu(layer(hidden))
        # No activation on the output layer.
        return self.fc3(hidden)
    

class DiffusionModel(nn.Module):
    '''
    Forward (noising) process with a linear beta schedule.

    Holds the schedule tensors ``betas``, ``alphas`` (= 1 - betas) and
    ``alpha_bars`` (cumulative product of alphas), each of length ``timesteps``.
    They are registered as non-persistent buffers: they follow the module through
    ``.to(device)`` / ``.double()`` but are not written to ``state_dict()``, which
    therefore stays empty as it was when they were plain attributes.
    '''

    def __init__(self, timesteps=100):
        super().__init__()
        beta_start = 0.0001
        beta_end = 0.02
        self.timesteps = timesteps

        betas = self.linear_schedule(beta_start, beta_end, timesteps)
        alphas = 1.0 - betas
        alpha_bars = torch.cumprod(alphas, dim=0)

        # persistent=False keeps checkpoints identical to the attribute-based version.
        self.register_buffer("betas", betas, persistent=False)
        self.register_buffer("alphas", alphas, persistent=False)
        self.register_buffer("alpha_bars", alpha_bars, persistent=False)

    def linear_schedule(self, beta_start, beta_end, timesteps):
        '''
        Linear schedule for betas: ``timesteps`` evenly spaced values from
        beta_start to beta_end (both included).
        '''
        return torch.linspace(beta_start, beta_end, timesteps)

    def forward(self, x, t):
        '''
        Input : x_0 initial state-action vector, t current timestep
        Output : (x_t, epsilon) - the noisy state-action vector and the noise used
        x_t = sqrt(alpha_bars_t) * x_0 + sqrt(1 - alpha_bars_t) * epsilon
        where epsilon ~ N(0, 1)
        '''
        noise = torch.randn_like(x)
        # Column vector so one coefficient per row broadcasts across the features.
        alpha_bars_t = self.alpha_bars[t].view(-1, 1)
        x_t = torch.sqrt(alpha_bars_t) * x + torch.sqrt(1.0 - alpha_bars_t) * noise

        return x_t, noise

    def reverse_diffusion(self, model, x_t, t):
        '''
        Return ``model``'s noise prediction for ``x_t`` at timestep(s) ``t``.

        Only the noise prediction is produced; no x_t -> x_{t-1} denoising update
        is applied. The previous body additionally looked up alpha_bars for t and
        t-1 without using them, and its ``if t > 0`` test raised for any tensor
        ``t`` holding more than one element, so that dead code was removed.
        '''
        return model(x_t, t)



def train_diffusion_model(network, diffusion_model, dataloader, epochs, lr, log_every=1):
    '''
    Train ``network`` to predict the noise added by ``diffusion_model``.

    For every batch a random timestep is drawn per row, the batch is noised with
    ``diffusion_model(data, t)`` and ``network(x_t, t)`` is regressed onto the
    returned noise with an MSE loss, optimised with Adam.

    Input :
        network         module called as network(x_t, t); its parameters are optimised
        diffusion_model object exposing ``timesteps`` and callable as
                        diffusion_model(data, t) -> (x_t, noise)
        dataloader      sized iterable of batch tensors
        epochs          number of passes over the dataloader
        lr              Adam learning rate
        log_every       print the batch loss every ``log_every`` batches
                        (default 1 = every batch, 0/None = epoch summaries only)
    Output : None (progress is printed)
    Raises : ValueError if the dataloader has no batches
    '''
    num_batches = len(dataloader)
    if num_batches == 0:
        # Previously this surfaced as a ZeroDivisionError when averaging the epoch loss.
        raise ValueError("dataloader is empty; nothing to train on")

    optimizer = torch.optim.Adam(network.parameters(), lr=lr)
    criterion = nn.MSELoss()

    for epoch in range(epochs):
        epoch_loss = 0.0
        for batch_idx, data in enumerate(dataloader):
            optimizer.zero_grad()

            # Sample timesteps on the batch's device so the indices can be used
            # together with the batch without a device mismatch.
            t = torch.randint(
                0, diffusion_model.timesteps, (data.shape[0],), device=data.device
            )
            x_t, noise = diffusion_model(data, t)

            pred_noise = network(x_t, t)

            loss = criterion(pred_noise, noise)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            epoch_loss += batch_loss
            if log_every and batch_idx % log_every == 0:
                print(f"Epoch {epoch+1}/{epochs} Batch {batch_idx+1}/{num_batches} Loss: {batch_loss:.4f}")

        avg_loss = epoch_loss / num_batches
        print(f"Epoch {epoch+1}/{epochs} Loss: {avg_loss:.4f}")





# Load the expert trajectories used to fit the diffusion model.
# NOTE(review): pickle.load can execute arbitrary code while deserialising; only
# open trajectory files that come from a trusted source.
with open("Expert_trajectories/train/1.pkl", "rb") as f:
    data = pickle.load(f)

# Flatten every step of every trajectory into one row: the observation values
# followed by the action values of that step.
# NOTE(review): `obs` (and `act`, `i`, `j`) stay defined at module level after this
# loop; the rollout cell further down prints and uses `obs` without assigning it first.
data_list = []
for i in range(len(data)):
    for j in range(len(data[i].acts)):
        obs = data[i].obs[j].tolist()
        act = data[i].acts[j].tolist()
        data_object = obs + act
        data_list.append(data_object)

# From here on `data_list` is a tensor with one row per (observation, action) pair.
data_list = torch.tensor(data_list)

dataloader = torch.utils.data.DataLoader(
    data_list,
    batch_size=32,
    shuffle=True
)

# Must equal the row width of `data_list`; presumably 165 observation values +
# 6 action values (the dimensions used by the policy cell) - TODO confirm.
state_action_vector_dim = 171
timesteps = 100
# NOTE(review): the next two lines rebind the class names to instances. Later cells
# rely on `DiffusionModel` / `DFNetwork` being the instances, and re-running this cell
# without re-running the class definitions calls the instance instead of the constructor.
DiffusionModel = DiffusionModel(timesteps)
DFNetwork = DFNetwork(state_action_vector_dim, 128, timesteps)
# 500 epochs, Adam learning rate 0.001.
train_diffusion_model(DFNetwork, DiffusionModel, dataloader, 500, 0.001)

Epoch 1/500 Batch 1/13 Loss: 1.0191
Epoch 1/500 Batch 2/13 Loss: 0.9721
Epoch 1/500 Batch 3/13 Loss: 1.0111
Epoch 1/500 Batch 4/13 Loss: 1.0129
Epoch 1/500 Batch 5/13 Loss: 0.9867
Epoch 1/500 Batch 6/13 Loss: 1.0349
Epoch 1/500 Batch 7/13 Loss: 1.0393
Epoch 1/500 Batch 8/13 Loss: 1.0398
Epoch 1/500 Batch 9/13 Loss: 0.9842
Epoch 1/500 Batch 10/13 Loss: 1.0074
Epoch 1/500 Batch 11/13 Loss: 0.9855
Epoch 1/500 Batch 12/13 Loss: 1.0077
Epoch 1/500 Batch 13/13 Loss: 0.9868
Epoch 1/500 Loss: 1.0067
Epoch 2/500 Batch 1/13 Loss: 0.9765
Epoch 2/500 Batch 2/13 Loss: 1.0019
Epoch 2/500 Batch 3/13 Loss: 0.9761
Epoch 2/500 Batch 4/13 Loss: 0.9773
Epoch 2/500 Batch 5/13 Loss: 1.0191
Epoch 2/500 Batch 6/13 Loss: 0.9800
Epoch 2/500 Batch 7/13 Loss: 0.9764
Epoch 2/500 Batch 8/13 Loss: 1.0185
Epoch 2/500 Batch 9/13 Loss: 1.0263
Epoch 2/500 Batch 10/13 Loss: 0.9916
Epoch 2/500 Batch 11/13 Loss: 0.9929
Epoch 2/500 Batch 12/13 Loss: 1.0275
Epoch 2/500 Batch 13/13 Loss: 1.0069
Epoch 2/500 Loss: 0.9978
Epoch 

In [10]:
class PolicyNetwork(nn.Module):
    '''
    Three-layer MLP policy mapping a state vector to a bounded action vector.

    The last layer goes through tanh and is rescaled element-wise from [-1, 1]
    to [action_min, action_max].

    ``action_min`` / ``action_max`` are stored as non-persistent buffers in the
    default floating dtype: they move with the module on ``.to(device)``, cannot
    promote the output to float64 when the bounds arrive as float64 arrays, and
    are not written to ``state_dict()`` (checkpoint keys are unchanged).
    '''

    def __init__(self, state_vector_dim, action_vector_dim, action_min, action_max):
        super().__init__()
        self.fc1 = nn.Linear(state_vector_dim, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, action_vector_dim)
        self.register_buffer("action_min", self._as_bound(action_min), persistent=False)
        self.register_buffer("action_max", self._as_bound(action_max), persistent=False)

    @staticmethod
    def _as_bound(values):
        '''Copy a list / array / tensor of bounds into a detached float tensor.'''
        # as_tensor avoids the copy-construct warning for tensor inputs; clone()
        # makes sure the buffer never aliases the caller's array.
        return torch.as_tensor(values, dtype=torch.get_default_dtype()).detach().clone()

    def forward(self, x):
        '''
        Input : x state vector(s), last dimension of size state_vector_dim
        Output : action vector(s) with every component inside [action_min, action_max]
        '''
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = torch.tanh(self.fc3(x))
        # Affine map of the tanh range [-1, 1] onto [action_min, action_max].
        x = self.action_min + (0.5 * (self.action_max - self.action_min) * (x + 1))

        return x

def diffusion_augmented_policy(policy_network, network, diffusion_model, dataloader, epochs, lmda, lr,
                               state_dim=165, log_every=1):
    '''
    Train ``policy_network`` with behavioural cloning plus a diffusion-model term.

    Per batch:
        L_BC     = MSE(policy(obs), expert_action)
        L_expert = noise-prediction MSE of ``network`` on the expert (obs, action) rows
        L_agent  = noise-prediction MSE of ``network`` on (obs, policy(obs)) rows
        L_DM     = max(L_agent - L_expert, 0)
        loss     = L_BC + lmda * L_DM
    The same timesteps ``t`` are used for the expert and the agent rows.
    Only ``policy_network``'s parameters are handed to the optimiser.

    Input :
        policy_network  module called as policy_network(obs) -> predicted actions
        network         noise predictor called as network(x_t, t)
        diffusion_model object exposing ``timesteps`` and callable as
                        diffusion_model(x, t) -> (x_t, noise)
        dataloader      sized iterable of 2-D batches, each row = observation
                        values followed by action values
        epochs          number of passes over the dataloader
        lmda            weight of L_DM in the total loss
        lr              Adam learning rate
        state_dim       number of leading columns that form the observation
                        (default 165, the value that used to be hard-coded)
        log_every       print the batch loss every ``log_every`` batches
                        (default 1 = every batch, 0/None = epoch summaries only)
    Output : None (progress is printed)
    Raises : ValueError if the dataloader has no batches
    '''
    num_batches = len(dataloader)
    if num_batches == 0:
        # Previously this surfaced as a ZeroDivisionError when averaging the epoch loss.
        raise ValueError("dataloader is empty; nothing to train on")

    optimizer = torch.optim.Adam(policy_network.parameters(), lr=lr)
    criterion = nn.MSELoss()

    for epoch in range(epochs):
        epoch_loss = 0.0
        for batch_idx, data in enumerate(dataloader):
            optimizer.zero_grad()

            # L_BC
            batch_obs = data[:, :state_dim]
            batch_act = data[:, state_dim:]
            pred_act = policy_network(batch_obs)
            L_BC = criterion(pred_act, batch_act)

            # Timesteps live on the batch's device so they can index/feed modules there.
            t = torch.randint(
                0, diffusion_model.timesteps, (data.shape[0],), device=data.device
            )

            # L_expert: does not depend on the policy, so it only acts as a
            # threshold for L_DM - no graph is needed for it.
            with torch.no_grad():
                x_t_real_action, noise_real_action = diffusion_model(data, t)
                loss_expert = criterion(network(x_t_real_action, t), noise_real_action)

            # L_agent: same observations, actions replaced by the policy's prediction,
            # so gradients reach the policy through the noise predictor's input.
            data_pred_action = torch.cat([batch_obs, pred_act], dim=1)
            x_t_pred_action, noise_pred_action = diffusion_model(data_pred_action, t)
            loss_agent = criterion(network(x_t_pred_action, t), noise_pred_action)

            # Tensor-native hinge instead of Python's max(tensor, 0).
            L_DM = torch.clamp(loss_agent - loss_expert, min=0)

            loss = L_BC + lmda * L_DM

            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            epoch_loss += batch_loss
            if log_every and batch_idx % log_every == 0:
                print(f"Epoch {epoch+1}/{epochs} Batch {batch_idx+1}/{num_batches} Loss: {batch_loss:.4f}")

        avg_loss = epoch_loss / num_batches

        print(f"Epoch {epoch+1}/{epochs} Loss: {avg_loss:.4f}")


import os

import AnalysisEnv

# Build the environment; it is only used here to obtain the action bounds and is
# kept in `Env` for later cells.
dataset_path = "Datasets/Networks/1.tsv"
Env = AnalysisEnv.AnalysisEnv(dataset_path)

min_action, max_action = Env.get_action_space()

# Each dataloader row is split into 165 observation values and 6 action values.
state_vector_dim = 165
action_vector_dim = 6

# Create the checkpoint directory *before* the long training run: torch.save does
# not create missing directories, and failing after 500 epochs would lose the model.
policy_checkpoint_path = "Models/Networks/policy.pth"
os.makedirs(os.path.dirname(policy_checkpoint_path), exist_ok=True)

# NOTE(review): this rebinds the class name `PolicyNetwork` to the trained instance.
# Later cells call `PolicyNetwork(obs)` on the instance, so the name is kept as is;
# re-running this cell requires re-running the class definition first.
PolicyNetwork = PolicyNetwork(state_vector_dim, action_vector_dim, min_action, max_action)
# `DFNetwork` and `DiffusionModel` are the trained instances from the previous cell.
# 500 epochs, lmda = 0.1, Adam learning rate 0.001.
diffusion_augmented_policy(PolicyNetwork, DFNetwork, DiffusionModel, dataloader, 500, 0.1, 0.001)

# Save the model
torch.save(PolicyNetwork.state_dict(), policy_checkpoint_path)

Epoch 1/500 Batch 1/13 Loss: 33.8063
Epoch 1/500 Batch 2/13 Loss: 26.6174
Epoch 1/500 Batch 3/13 Loss: 27.6443
Epoch 1/500 Batch 4/13 Loss: 23.5096
Epoch 1/500 Batch 5/13 Loss: 23.3918
Epoch 1/500 Batch 6/13 Loss: 19.1172
Epoch 1/500 Batch 7/13 Loss: 29.8027
Epoch 1/500 Batch 8/13 Loss: 11.0019
Epoch 1/500 Batch 9/13 Loss: 16.8101
Epoch 1/500 Batch 10/13 Loss: 26.9266
Epoch 1/500 Batch 11/13 Loss: 15.4824
Epoch 1/500 Batch 12/13 Loss: 19.2892
Epoch 1/500 Batch 13/13 Loss: 25.6295
Epoch 1/500 Loss: 23.0022
Epoch 2/500 Batch 1/13 Loss: 18.7837
Epoch 2/500 Batch 2/13 Loss: 19.2221
Epoch 2/500 Batch 3/13 Loss: 16.5686
Epoch 2/500 Batch 4/13 Loss: 17.8999
Epoch 2/500 Batch 5/13 Loss: 12.2979
Epoch 2/500 Batch 6/13 Loss: 19.6752
Epoch 2/500 Batch 7/13 Loss: 23.1341
Epoch 2/500 Batch 8/13 Loss: 15.3145
Epoch 2/500 Batch 9/13 Loss: 22.3246
Epoch 2/500 Batch 10/13 Loss: 22.4857
Epoch 2/500 Batch 11/13 Loss: 18.2210
Epoch 2/500 Batch 12/13 Loss: 13.5881
Epoch 2/500 Batch 13/13 Loss: 27.5248
Epoc

In [11]:
# Roll the trained policy out in the environment for at most 12 steps.
# NOTE(review): the return value of Env.reset() is discarded, so `obs` below is
# whatever an earlier cell left in the kernel (the data-loading loop assigns `obs`),
# not necessarily the observation produced by this reset. Presumably this should be
# `obs = Env.reset()` - confirm against AnalysisEnv before changing.
Env.reset()
print(obs)

for i in range(12):
    # print('Observation at timestep', i, ':', obs)
    obs = torch.tensor(obs).float()
    # `PolicyNetwork` is the trained instance created in the previous cell, not the class.
    action = PolicyNetwork(obs)
    # Drop the autograd graph and hand the environment a NumPy array.
    action = action.detach().numpy()

    print('Action at timestep', i, ':', action)
    # NOTE(review): the preprocessed action is only printed (under the same label as
    # the raw one); Env.step below receives the raw policy output. Confirm whether
    # step() applies preprocess_action internally.
    actual_action = Env.preprocess_action(action)
    print('Action at timestep', i, ':', actual_action)
    obs, _, done, _ = Env.step(action)
    # Stop early once the environment reports the episode as finished.
    if done == 1:
        break

None


[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0011563367443159223, 0.0, 0.1633061170578003, 0.0002312673459528014, 0.0, 0.03441566973924637, 0.0002312673459528014, 0.0, 0.03441566973924637, 0.0003469010116532445, 0.0, 0.03907942771911621, 0.9422987699508667, 0.0, 0.9882895350456238, 0.42711707949638367, 0.0004625346919056028, 0.850688099861145, 0.02082369290292263, 0.0004625346919056028, 0.08669976145029068, 0.0011563367443159223, 0.0, 0.1633061170578003, 1.0, 0.0, 0.9999997615814209, 0.0