## 1. Install necessary Libraries

In [1]:
# ! pip install open_clip_torch matplotlib
# ! pip install opencv-python
# ! sudo apt-get update -y
# ! sudo apt-get install libgl1-mesa-glx -y

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import os
import open_clip
open_clip.list_pretrained()
from open_clip import tokenizer

In [2]:
model, _, preprocess = open_clip.create_model_and_transforms('convnext_base_w', pretrained='laion2b_s13b_b82k_augreg')

## 2. Generate Appropriate Action descriptions for Driving

In [3]:
descriptions = []

# Throttle values range from 0 to 1 in increments of 0.25
throttle_values = [i * 0.25 for i in range(5)]

# Steer values range from -1 to 1 in increments of 0.1
steer_values = [i * 0.1 for i in range(-10, 11)]

# Brake values range from 0 to 1 in increments of 0.25
brake_values = [i * 0.25 for i in range(5)]

# Constructing the action descriptions
for throttle in throttle_values:
    for steer in steer_values:
        for brake in brake_values:
            description = f"In order to drive adhering to traffic rules, you need to apply a throttle of {throttle}, a steer value of {steer}, and a brake value of {brake}."
            descriptions.append(description)


len(descriptions)  # No of actions for driving

525

In [3]:
text_ob = tokenizer.tokenize(descriptions)
model = model.to("cuda:0")
with torch.no_grad():            
    text_features_og = model.encode_text(text_ob.to("cuda:0"), normalize=True) 

## 3. Create a Policy Network Using CLIP model

In [5]:
for param in model.parameters():
    param.requires_grad = False

model, _, preprocess = open_clip.create_model_and_transforms('convnext_base_w', pretrained='laion2b_s13b_b82k_augreg')

class PPO_Network(nn.Module):
    # nature paper architecture
    
    def __init__(self):
        super(PPO_Network, self).__init__()

        self.model = model
        self.preprocess = preprocess
        
        # Define MLP layers for image encoder
        self.image_mlp = nn.Sequential(
            nn.Linear(640, 640),  
            nn.ReLU(),
            nn.Linear(640, 640)   
        )

        # Define MLP layers for text encoder
        self.text_mlp = nn.Sequential(
            nn.Linear(640, 640),  # Assuming the text encoder also outputs 1000 units
            nn.ReLU(),
            nn.Linear(640, 640)   # The output size is set back to 1000 to match the text encoder output
        )

        # Define Value Head
        self.value_head = nn.Sequential(
            nn.Linear(640*2, 1),  # Concatenate image and text features
            
        )
    
    def forward(self, image, text_features_og):
        # Get the Image features from frozen model
        image_features_og = self.model.encode_image(image, normalize=True) 
        
        # # Get the Text features from frozen model
        # text_features_og = self.model.encode_text(text, normalize=True) 
        
        # Pass through the trainable params
        image_features_ba = self.image_mlp(image_features_og)
        text_features_ba = self.text_mlp(text_features_og)

        image_features_ba = F.normalize(image_features_ba, dim=-1) 
        text_features_ba = F.normalize(text_features_ba, dim=-1)

        action_logits = image_features_ba @ text_features_ba.T

        action_probs = action_logits.softmax(dim=-1)

        weighted_text_features = action_probs@text_features_ba  

        combined_features = torch.cat((image_features_ba, weighted_text_features), dim=1)
        
        # Compute the value
        value = self.value_head(combined_features)
        

        return action_probs, value

## 4. Wrap the Policy Network and make it PPO Agent

In [6]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

class PPO_Agent(nn.Module):
    
    def __init__(self):
        super().__init__()
        
        self.network = PPO_Network()
    
    def forward(self, image, text_features_og):
        policy, value = self.network(image, text_features_og)
        return policy, value
    
    def select_action(self, policy):
        return np.random.choice(range(525) , 1, p=policy)[0]

    def parallel_select_action(self, policy):
        # policy shape: [N, 525]
        actions = []
        for p in policy:
            # print(p)
            action = np.random.choice(range(4),1,p=p)[0]
            actions.append(action)
        return actions

## 5. Creating dataset loader from Carla (Note: you need to connect to CARLA simulator to generate dataset)

In [12]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

class Batch_DataSet(torch.utils.data.Dataset):

    def __init__(self, obs, actions, adv, v_t, old_action_prob):
        super().__init__()
        self.obs = obs
        self.actions = actions
        self.adv = adv
        self.v_t = v_t
        self.old_action_prob = old_action_prob
        
    def __len__(self):
        return self.obs.shape[0]
    
    def __getitem__(self, i):
        return self.obs[i],self.actions[i],self.adv[i],self.v_t[i],self.old_action_prob[i]

## 6. Generating Parallelenvs for having multiple runs and collecting more data in parallel

In [14]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from PIL import Image
import time

import threading
from queue import Queue

device = torch.device("cuda:0")
dtype = torch.float

class Logger:
    
    def __init__(self, filename):
        self.filename = filename
        
        f = open(f"{self.filename}.csv", "w")
        f.close()
        
    def log(self, msg):
        f = open(f"{self.filename}.csv", "a+")
        f.write(f"{msg}\n")
        f.close()

cur_step = 0          
class Env_Runner:
    
    def __init__(self, envs, agent, logger_folder):
        super().__init__()
        
        self.envs = envs
        self.agent = agent
        
        self.logger = Logger(f'{logger_folder}/training_info')
        self.logger.log("training_step, return")
        
        # self.obs = self.env.reset()
        self.parallel_ob = [env.reset() for env in self.envs]

    def preprocess_image(self, image, queue):
        processed_image = self.agent.network.preprocess(Image.fromarray(image).convert("RGB"))
        queue.put(processed_image)

    def parallel_preprocess(self, images):
        queue = Queue()
        threads = []

        # Start threads for preprocessing
        for image in images:
            thread = threading.Thread(target=self.preprocess_image, args=(image, queue))
            thread.start()
            threads.append(thread)

        # Wait for all threads to complete
        for thread in threads:
            thread.join()

        # Collect results
        processed_images = [queue.get() for _ in images]
        return processed_images
        
    def run(self, steps):
        
        global cur_step
        
        parallel_obs = []
        parallel_actions = []
        parallel_rewards = []
        parallel_dones = []
        parallel_values = []
        parallel_action_prob = []
        
        for step in range(steps):

            start_time = time.time()
            parallel_image_ob = self.parallel_preprocess(self.parallel_ob)
            parallel_image_ob = torch.stack(parallel_image_ob)
            end_time = time.time()

            # print("profile 1",end_time - start_time)

            # image_ob = agent.network.preprocess(Image.fromarray(self.ob).convert("RGB"))
            start_time = time.time()
            self.parallel_ob = parallel_image_ob.to(device).to(dtype)
            
            with torch.no_grad():
                parallel_policy, parallel_value = self.agent(self.parallel_ob, text_features_og.to(device))

            # print(parallel_policy.shape)
            parallel_action = self.agent.parallel_select_action(parallel_policy.detach().cpu().numpy())
            
            # end_time = time.time()

            # print("profile 2",end_time - start_time)

            start_time = time.time()
            for my_ob in self.parallel_ob:
                parallel_obs.append(my_ob)
            for action in parallel_action:
                parallel_actions.append(action)
            for value in parallel_value:
                parallel_values.append(value.detach()) 
            for i, policy in enumerate(parallel_policy):
                parallel_action_prob.append(policy[parallel_action[i]].detach())

            # end_time = time.time()

            # print("profile 3",end_time - start_time)

            self.parallel_ob = []
            parallel_r = []
            parallel_done = []
            parallel_info = []
            parallel_additional_done = []

            # start_time = time.time()
            for i,action in enumerate(parallel_action):
            
                observation, r, done, info, additional_done = self.envs[i].step(action)

                self.parallel_ob.append(observation)
                parallel_r.append(r)
                parallel_done.append(done)
                parallel_info.append(info)
                parallel_additional_done.append(additional_done)
            # end_time = time.time()

            # print("profile 4",end_time - start_time)

            start_time = time.time()
            for i,done in enumerate(parallel_done):
                if done: # real environment reset, other add_dones are for learning purposes
                    self.parallel_ob[i] = self.envs[i].reset()
                    if "return" in parallel_info[i]:
                        self.logger.log(f'{cur_step+step},{parallel_info[i]["return"]}')
            
            for r in parallel_r:
                parallel_rewards.append(r)
            for i in range(0,len(parallel_done)):
                parallel_dones.append(parallel_done[i] or parallel_additional_done[i])

            # end_time = time.time()

            # print("profile 5",end_time - start_time)
            
        cur_step += steps
                                    
        return [parallel_obs, parallel_actions, parallel_rewards, parallel_dones, parallel_values, parallel_action_prob]

## 7. Computing Advantage estimates necessary for PPO training

In [15]:
import numpy as np
import argparse
import gym
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import time
import random
import os

# import multiprocessing
from PIL import Image
# from functools import partial
import threading
from queue import Queue

from torch.utils.data import DataLoader
from torch.utils.data import Dataset

device = torch.device("cuda:0")
dtype = torch.float



def compute_advantage_and_value_targets(rewards, values, dones, gamma, lam):
    
    advantage_values = []
    old_adv_t = torch.tensor(0.0).to(device)
    
    value_targets = []
    old_value_target = values[-1]
    
    for t in reversed(range(len(rewards)-1)):
        
        if dones[t]:
            old_adv_t = torch.tensor(0.0).to(device)
        
        # ADV
        delta_t = rewards[t] + (gamma*(values[t+1])*int(not dones[t+1])) - values[t]
        
        A_t = delta_t + gamma*lam*old_adv_t
        advantage_values.append(A_t[0])
        
        old_adv_t = delta_t + gamma*lam*old_adv_t
        
        # VALUE TARGET
        value_target = rewards[t] + gamma*old_value_target*int(not dones[t+1])
        value_targets.append(value_target[0])
        
        old_value_target = value_target
    
    advantage_values.reverse()
    value_targets.reverse()
    
    return advantage_values, value_targets

## 8. Setting up hyper-parameters

In [16]:
folder_name = time.asctime(time.gmtime()).replace(" ","_").replace(":","_")
os.mkdir(folder_name)

# save the hyperparameters in a file
f = open(f'{folder_name}/args.txt','w')
# for i in args.__dict__:
#     f.write(f'{i},{args.__dict__[i]}\n')
# f.close()

# arguments
env_name = 'Carla'
num_stacked_frames = 4
start_lr = 2.5e-4 
gamma = 0.99
lam = 0.95
minibatch_size = 32
T = 129
c1 = 1.0
c2 = 0.01
actors = 8
start_eps = 0.1
epochs = 3
total_steps = 10000000
save_model_steps = 1000000
report = 50000


agent = PPO_Agent().to(device)
optimizer = optim.Adam(agent.parameters(), lr=start_lr)

envs = []
env_runners = []
for actor in range(actors):
    
    raw_env = gym.make(env_name)
    env = Carla_Wrapper(raw_env, env_name, num_stacked_frames, use_add_done=True)

    envs.append(env)
    
env_runners.append(Env_Runner(envs, agent, folder_name))

num_model_updates = 0

## 9. Training CLIP model using it as a policy network in Proximal Policy Optimization

In [None]:
# import time

# start_time = time.time()
start_time = time.time()
while cur_step < total_steps:

    # change lr and eps over time
    alpha = 1 - (cur_step*8 / total_steps)
    current_lr = start_lr * alpha
    current_eps = start_eps * alpha
    
    #set lr
    for g in optimizer.param_groups:
        g['lr'] = current_lr
    
    # get data
    batch_obs, batch_actions, batch_adv, batch_v_t, batch_old_action_prob = None, None, None, None, None
    
    for env_runner in env_runners:
        obs, actions, rewards, dones, values, old_action_prob = env_runner.run(T)
        adv, v_t = compute_advantage_and_value_targets(rewards, values, dones, gamma, lam)
    
        # assemble data from the different runners 
        batch_obs = torch.stack(obs[:-1]) if batch_obs == None else torch.cat([batch_obs,torch.stack(obs[:-1])])
        batch_actions = np.stack(actions[:-1]) if batch_actions is None else np.concatenate([batch_actions,np.stack(actions[:-1])])
        batch_adv = torch.stack(adv) if batch_adv == None else torch.cat([batch_adv,torch.stack(adv)])
        batch_v_t = torch.stack(v_t) if batch_v_t == None else torch.cat([batch_v_t,torch.stack(v_t)]) 
        batch_old_action_prob = torch.stack(old_action_prob[:-1]) if batch_old_action_prob == None else torch.cat([batch_old_action_prob,torch.stack(old_action_prob[:-1])])
    
    # load into dataset/loader
    dataset = Batch_DataSet(batch_obs,batch_actions,batch_adv,batch_v_t,batch_old_action_prob)
    dataloader = DataLoader(dataset, batch_size=minibatch_size, num_workers=0, shuffle=True)

    # update
    for epoch in range(epochs):
         
        # sample minibatches
        for i, batch in enumerate(dataloader):
            optimizer.zero_grad()
            
            if i >= 8:
                break
            
            # get data
            obs, actions, adv, v_target, old_action_prob = batch 
            
            # adv = adv.squeeze(1)
            # normalize adv values
            adv = ( adv - torch.mean(adv) ) / ( torch.std(adv) + 1e-8)
            
            # get policy actions probs for prob ratio & value prediction
            policy, v = agent(obs, text_features_og)
            # get the correct policy actions
            pi = policy[range(minibatch_size),actions.long()]
            
            # probaility ratio r_t(theta)
            probability_ratio = pi / (old_action_prob + 1e-8)
            
            # compute CPI
            CPI = probability_ratio * adv
            # compute clip*A_t
            clip = torch.clamp(probability_ratio,1-current_eps,1+current_eps) * adv     
            
            # policy loss | take minimum
            L_CLIP = torch.mean(torch.min(CPI, clip))
            
            # value loss | mse
            L_VF = torch.mean(torch.pow(v - v_target,2))
            
            # policy entropy loss 
            S = torch.mean( - torch.sum(policy * torch.log(policy + 1e-8),dim=1))

            loss = - L_CLIP + c1 * L_VF - c2 * S
            loss.backward()
            print((cur_step-1)*8,loss.item())
            optimizer.step()
    
        num_model_updates += 1

        # print time
        if cur_step%50000 < T*actors:
            end_time = time.time()
            print(f'*** total steps: {cur_step} | time(50K): {end_time - start_time} ***')
            start_time = time.time()
        
        # save the network after some time
        if cur_step%save_model_steps < T*actors:
            torch.save(agent,f'{folder_name}/{env_name}-{cur_step}.pt')

1024 0.180461585521698
1024 0.136705219745636
1024 0.1776484102010727
1024 0.1103881299495697
1024 0.07995184510946274
1024 0.14868797361850739
1024 0.17089125514030457
1024 0.1291157603263855
*** total steps: 129 | time(50K): 67.93120288848877 ***
1024 0.28440141677856445
1024 0.3143439292907715
1024 0.22455736994743347
1024 0.014558900147676468
1024 0.1569065898656845
1024 0.22965970635414124
1024 0.10162125527858734
1024 0.31349024176597595
*** total steps: 129 | time(50K): 1.3328235149383545 ***
1024 0.07154329866170883
1024 0.20368950068950653
1024 0.00512306671589613
1024 0.26747676730155945
1024 0.14804409444332123
1024 0.11977406591176987
1024 0.06720372289419174
1024 0.18783649802207947
*** total steps: 129 | time(50K): 1.6580190658569336 ***
2056 0.003173859789967537
2056 0.01168384961783886
2056 0.028474578633904457
2056 0.020921753719449043
2056 0.0350448302924633
2056 0.031031202524900436
2056 0.02575899474322796
2056 0.005799713544547558
*** total steps: 258 | time(50K): 