In [31]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
from typing import List, Tuple
from gym import spaces

from gym.wrappers import FlattenObservation

import collections
from collections import namedtuple, deque
import tqdm
import matplotlib.pyplot as plt
import random
import gymnasium as gym

from IPython.display import clear_output
from IPython import display


torch.set_default_dtype(torch.float64)

In [18]:


class FullyConnectedModel(nn.Module):
    def __init__(self, input_size, output_size):
        super(FullyConnectedModel, self).__init__()

        # Define layers with ReLU activation
        self.linear1 = nn.Linear(input_size, 200)
        self.activation1 = nn.ReLU()
        self.linear2 = nn.Linear(200, 200)
        self.activation2 = nn.ReLU()
        self.linear3 = nn.Linear(200, 50)
        self.activation3 = nn.ReLU()

        # Output layer without activation function
        self.output_layer = nn.Linear(50, output_size)

        # Initialization using Xavier uniform (a popular technique for initializing weights in NNs)
        nn.init.xavier_normal_(self.linear1.weight)
        nn.init.xavier_normal_(self.linear2.weight)
        nn.init.xavier_normal_(self.linear3.weight)
        nn.init.xavier_normal_(self.output_layer.weight)

    def forward(self, inputs):
        # Forward pass through the layers
        x = self.activation1(self.linear1(inputs))
        x = self.activation2(self.linear2(x))
        x = self.activation3(self.linear3(x))
        x = self.output_layer(x)
        return x

    
class QNetwork:
    def __init__(self, env,  lr, input=150 , logdir=None):
        # Define Q-network with specified architecture
        self.net = FullyConnectedModel(input, 3)
        self.env = env
        self.lr = lr 
        self.logdir = logdir
        self.optimizer = optim.Adam(self.net.parameters(), lr=self.lr)

    def load_model(self, model_file):
        # Load pre-trained model from a file
        return self.net.load_state_dict(torch.load(model_file))

    def load_model_weights(self, weight_file):
        # Load pre-trained model weights from a file
        return self.net.load_state_dict(torch.load(weight_file))

In [113]:
class ReplayMemory:
    def __init__(self, env, memory_size=50000, burn_in=10000):
        # Initializes the replay memory, which stores transitions recorded from the agent taking actions in the environment.
        self.memory_size = memory_size
        self.burn_in = burn_in
        self.memory = collections.deque([], maxlen=memory_size)
        self.env = env

    def sample_batch(self, batch_size=32):
        # Returns a batch of randomly sampled transitions to be used for training the model.
        return random.sample(self.memory, batch_size)

    def append(self, transition):
        # Appends a transition to the replay memory.
        self.memory.append(transition)

In [None]:
Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))

class DQN_Agent:

    def __init__(self, env, lr=5e-4, render=False):
        # Initialize the DQN Agent.
        self.env = env
        self.lr = lr
        self.policy_net = QNetwork(self.env, self.lr)
        self.target_net = QNetwork(self.env, self.lr)
        self.target_net.net.load_state_dict(self.policy_net.net.state_dict())  # Copy the weight of the policy network
        self.rm = ReplayMemory(self.env)
        self.burn_in_memory()
        self.batch_size = 32
        self.gamma = 0.99
        self.c = 0

    def burn_in_memory(self):
        # Initialize replay memory with a burn-in number of episodes/transitions.
        cnt = 0
        terminated = False
        truncated = False
        state, _ = self.env.reset()
        state = torch.tensor(state, dtype=torch.double).unsqueeze(0)

        # Iterate until we store "burn_in" buffer
        while cnt < self.rm.burn_in:
            # Reset environment if terminated or truncated
            if terminated or truncated:
                state, _ = self.env.reset()
                state = torch.tensor(state, dtype=torch.double).unsqueeze(0)
            
            # Randomly select an action (left or right) and take a step
            action = torch.tensor(self.env.action_space.sample()).reshape(1, 1)
            next_state, reward, terminated, truncated, _ = self.env.step(action.item())
            reward = torch.tensor([reward])
            if terminated:
                next_state = None
            else:
                next_state = torch.tensor(next_state, dtype=torch.double).unsqueeze(0)
                
            # Store new experience into memory
            transition = Transition(state, action, next_state, reward)
            self.rm.memory.append(transition)
            state = next_state
            cnt += 1

    def epsilon_greedy_policy(self, q_values, epsilon=0.05):
        # Implement an epsilon-greedy policy. 
        p = random.random()
        if p > epsilon:
            with torch.no_grad():
                return self.greedy_policy(q_values)
        else:
            return torch.tensor([[self.env.action_space.sample()]], dtype=torch.long)

    def greedy_policy(self, q_values):
        # Implement a greedy policy for test time.
        return torch.argmax(q_values)
        
    def train(self):
        # Train the Q-network using Deep Q-learning.
        state, _ = self.env.reset()
        state = torch.tensor(state, dtype=torch.double).unsqueeze(0)
        terminated = False
        truncated = False

        # Loop until reaching the termination state
        while not (terminated or truncated):
            with torch.no_grad():
                q_values = self.policy_net.net(state)

            # Decide the next action with epsilon greedy strategy
            action = self.epsilon_greedy_policy(q_values).reshape(1, 1)
            
            # Take action and observe reward and next state
            next_state, reward, terminated, truncated, _ = self.env.step(action.item())
            reward = torch.tensor([reward])
            if terminated:
                next_state = None
            else:
                next_state = torch.tensor(next_state, dtype=torch.double).unsqueeze(0)

            # Store the new experience
            transition = Transition(state, action, next_state, reward)
            self.rm.memory.append(transition)

            # Move to the next state
            state = next_state

            # Sample minibatch with size N from memory
            transitions = self.rm.sample_batch(self.batch_size)
            batch = Transition(*zip(*transitions))
            non_final_mask = torch.tensor(tuple(map(lambda s: s is not None, batch.next_state)), dtype=torch.bool)
            non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])
            state_batch = torch.cat(batch.state)
            action_batch = torch.cat(batch.action)
            reward_batch = torch.cat(batch.reward)

            # Get current and next state values
            state_action_values = self.policy_net.net(state_batch).gather(1, action_batch) # extract values corresponding to the actions Q(S_t, A_t)
            next_state_values = torch.zeros(self.batch_size)
            
            with torch.no_grad():
                # no next_state_value update if an episode is terminated (next_satate = None)
                # only update the non-termination state values (Ref: https://gymnasium.farama.org/tutorials/gymnasium_basics/handling_time_limits/)
                next_state_values[non_final_mask] = self.target_net.net(non_final_next_states).max(1)[0] # extract max value
                
            # Update the model
            expected_state_action_values = (next_state_values * self.gamma) + reward_batch
            criterion = torch.nn.MSELoss()
            loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))
            self.policy_net.optimizer.zero_grad()
            loss.backward()
            self.policy_net.optimizer.step()

            # Update the target Q-network in each 50 steps
            self.c += 1
            if self.c % 50 == 0:
                print(self.test())
                self.target_net.net.load_state_dict(self.policy_net.net.state_dict())

    def test(self, n = 30,model_file=None):
        # Evaluates the performance of the agent over 20 episodes.
        rewards = []
        for i in range(n):
            max_t = 1000
            state, _ = self.env.reset()

            for t in range(max_t):
                state = torch.from_numpy(state).double().unsqueeze(0)
                with torch.no_grad():
                    q_values = self.policy_net.net(state)
                action = self.greedy_policy(q_values)
                state, reward, terminated, truncated, _ = self.env.step(action.item())
                if terminated or truncated:
                    rewards.append(reward)
                    break

        return np.average(rewards)

In [115]:
class balloon3D:
    def __init__(self,width,length,height,obs_size):
        self.obs_size = obs_size
        self.x, self.y , self.z = 0,0,0
        self.height = height
        self.width = width
        self.length = length
        self.map = None

    def reset(self,start_x,start_y,start_z):
        self.x, self.y ,self.z = start_x,start_y,start_z

    def up(self):
        self.z = min(self.height-1,self.z+1)

    def down(self):
        self.z = max(self.z-1,0)

    def step(self):
        newx = int((self.x + self.map[self.x,self.y,self.z,0]) % self.width)
        newy = int((self.y + self.map[self.x,self.y,self.z,1]) % self.length)
        self.x = newx
        self.y = newy

    def generate_map(self,eta,mu):
        self.map = np.zeros((self.width,self.length,self.height,2))
        for i in range(self.width):
            for j in  range(self.length):
                for k in range(self.height):
                    for l in range(2):
                        r = np.random.random()
                        if(r>eta):
                            r= np.random.random()
                            if(r>mu):
                                self.map[i,j,k,l] = 1
                            else:
                                self.map[i,j,k,l] = -1


    def set_map(self,map):
        self.map = map

    def get_winds(self):
        r = self.obs_size
        self.obs = np.zeros((2*r+1,2*r+1,self.height,2))
        for i in range(self.x - r , self.x + r + 1):
            for j in range(self.y - r, self.y + r + 1):
                for k in range(0,self.height):
                    for l in range(2):
                        self.obs[i - self.x + r, j- self.y + r,k,l]  = (self.map[i%self.width,j%self.length,k,l])
        return self.obs
    
    def render_obs(self):
        r =self.obs_size*2+1
        for z in range(self.height):
            for j in range(r):
                s = ''
                for i in range(r):
                    l = ''
                    if(self.obs[i,j,z,0] == 0):
                        l += '.'
                    elif(self.obs[i,j,z,0] == 1):
                        l+='>'
                    elif(self.obs[i,j,z,0] == -1):
                        l+='<'

                    if(self.obs[i,j,z,1] == 0):
                        l += '.'
                    elif(self.obs[i,j,z,1] == 1):
                        l+='v'
                    elif(self.obs[i,j,z,1] == -1):
                        l+='^'                  
                    s+= l+ " "
                print(s)
            print(" ")
    

    
    def render(self):
        for z in range(self.height):
            for j in range(self.length):
                s = ''
                for i in range(self.width):
                    if self.x == i and self.y == j and self.z == z : 
                        s+= 'OO '
                    else:
                        l = ''
                        if(self.map[i,j,z,0] == 0):
                            l += '.'
                        elif(self.map[i,j,z,0] == 1):
                            l+='>'
                        elif(self.map[i,j,z,0] == -1):
                            l+='<'

                        if(self.map[i,j,z,1] == 0):
                            l += '.'
                        elif(self.map[i,j,z,1] == 1):
                            l+='v'
                        elif(self.map[i,j,z,1] == -1):
                            l+='^'
                        
                        s+= l+ " "
                print(s)
            print(" ")
                
        

In [116]:
class BalEnv(gym.Env):
    def __init__(self, render_mode=None, s_x = 10 , s_y =10 , s_z = 3, obs_size = 5 ):
        self.balloon = balloon3D(s_x,s_y,s_z,obs_size)
        self.obs_size = obs_size

        self._target_location = [s_x-1,s_y-1,s_z-1]


        # Observations are dictionaries with the agent's and the target's location.
        # Each location is encoded as an element of {0, ..., `size`}^2, i.e. MultiDiscrete([size, size]).
        self.observation_space = spaces.Dict(
            {
                #"agent": spaces.Box(0, size - 1, shape=(2,), dtype=int),
                #"target": spaces.Box(0, size - 1, shape=(2,), dtype=int),
                "map": spaces.Box(-1, 1, shape=(2*obs_size + 1,obs_size*2+1,s_z,2), dtype=int),
            }
        )
        

        self.action_space = spaces.Discrete(3)

    def reset(self):
        self.evo_time = 100

        self.reward = 0

        self.reward_x, self.reward_y = 0,0

        self.balloon.generate_map(0.35,0.5)

        self.time = 0

        self.balloon.reset(4,4,0)

        self._agent_location = [0,0,0]


        observation = self._get_obs()

        return observation, {}

    def step(self, action):
        self.time += 1

        if action == 0:
            self.balloon.step()
        if action == 1:
            self.balloon.up()
        if action == 2:
            self.balloon.down()


        
        self._agent_location = np.array([self.balloon.x , self.balloon.y,self.balloon.z])
        terminated  = ((self._agent_location[0]==self._target_location[0])and(self._agent_location[1]==self._target_location[1]))
        
        distx = min(((self._agent_location[0] - self._target_location[0])%self.balloon.width),(self._target_location[0] - self._agent_location[0])%self.balloon.width)
        disty = min(((self._agent_location[1] - self._target_location[1])%self.balloon.length),(self._target_location[1] - self._agent_location[1])%self.balloon.length)



        self.reward = -(np.sqrt((distx)**2 + (disty)**2))

        observation = self._get_obs()

        trucated = (self.time == 50)

        return observation, self.reward, terminated, trucated, {}
    
    def _get_obs(self):
        return {"map" : self.balloon.get_winds()}

    def render(self):
        self.balloon.render()
        
    def close(self):
        pass

In [118]:
env = BalEnv(obs_size=2)
wrapped_env = FlattenObservation(env)
print(wrapped_env.action_space)

training = DQN_Agent(wrapped_env)

Discrete(3)


In [135]:
training.test()

-6.114220690017832

In [134]:
for i in range(500):
    print(i)
    training.train()


0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225


KeyboardInterrupt: 