In [None]:
# Environment setup commands for this notebook. This cell has no execution
# count, i.e. it was not run as part of the saved session.
# NOTE(review): only the first line carries the `!` shell escape; the remaining
# lines are bare `conda` / `pip` commands — presumably meant to be run in a
# terminal (or via IPython automagic). Confirm before running this cell as-is.
!conda create --name ani python=3.8
conda install -c conda-forge jupyterlab
pip install open_clip_torch
pip install stable-baselines3
pip install gym[all]
# pyglet is the only dependency pinned to an exact version here.
# NOTE(review): presumably required for gym rendering compatibility — confirm.
pip install pyglet==1.5.27

In [1]:
import gym

from stable_baselines3 import DQN
import open_clip
from PIL import Image
import torch
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the pretrained CLIP model, its image preprocessing transform and the
# matching tokenizer. The architecture name is defined once so the model and the
# tokenizer cannot drift apart.
CLIP_ARCH = 'ViT-B-32-quickgelu'
CLIP_PRETRAINED = 'laion400m_e32'

model, _, preprocess = open_clip.create_model_and_transforms(CLIP_ARCH, pretrained=CLIP_PRETRAINED)
# The model is only used for inference (reward scoring) in this notebook, and
# open_clip returns it in train mode, so switch it to eval mode explicitly.
model.eval()
tokenizer = open_clip.get_tokenizer(CLIP_ARCH)

In [9]:
class CLIPEnv:
    """Environment wrapper whose reward is a CLIP image/text similarity.

    Every ``step`` forwards the action to the wrapped environment, renders the
    resulting frame as an RGB array and replaces the environment's own reward
    with the similarity between that frame and a fixed text prompt.

    Args:
        env: Wrapped environment. Must provide ``reset``, ``step`` (returning a
            4-tuple), ``render(mode="rgb_array")``, ``action_space``,
            ``observation_space`` and ``metadata``.
        clip_model: Model exposing ``encode_text`` and ``encode_image``.
        tokenizer: Callable turning a list of strings into model input.
        prompt: Text the rendered frames are compared against.
        preprocess: Optional image transform applied to the PIL image before
            encoding. When omitted, the module-level ``preprocess`` is used,
            which matches the behavior before this parameter existed.
    """

    def __init__(self, env, clip_model, tokenizer, prompt, preprocess=None):
        self.env = env
        self.model = clip_model
        self.tokenizer = tokenizer
        self.preprocess = preprocess

        # The prompt embedding is a constant; compute it without autograd so no
        # computation graph is kept alive for the lifetime of the wrapper.
        with torch.no_grad():
            text_features = self.model.encode_text(self.tokenizer([prompt]))
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        self.text_features = text_features

        # Mirror the wrapped environment's public description.
        self.action_space = self.env.action_space
        self.observation_space = self.env.observation_space
        self.metadata = self.env.metadata

    def reset(self, **kwargs):
        """Reset the wrapped environment and return its initial observation."""
        # Use the wrapped instance, not a notebook-level global named `env`.
        return self.env.reset(**kwargs)

    def step(self, action):
        """Advance one step and return ``(obs, clip_reward, done, info)``.

        The reward produced by the wrapped environment is discarded and
        replaced with the CLIP similarity of the rendered frame.
        """
        next_st, _, done, info = self.env.step(action)
        img = self.env.render(mode="rgb_array")
        clip_rwd = self.get_clip_reward(img)
        return next_st, clip_rwd, done, info

    def render(self, mode="human", **kwargs):
        """Render the wrapped environment."""
        return self.env.render(mode=mode, **kwargs)

    def close(self):
        """Close the wrapped environment."""
        return self.env.close()

    def get_clip_reward(self, state):
        """Score an RGB frame against the prompt.

        Args:
            state: Array-like image; it is cast to ``uint8`` and converted to a
                PIL image before preprocessing.

        Returns:
            NumPy array holding the similarity between the normalized image
            embedding and the normalized prompt embedding (first batch row).
        """
        transform = getattr(self, "preprocess", None)
        if transform is None:
            # Fall back to the module-level transform created with the model.
            transform = preprocess
        image = transform(Image.fromarray(np.uint8(state))).unsqueeze(0)
        with torch.no_grad(), torch.cuda.amp.autocast():
            image_features = self.model.encode_image(image)
            image_features /= image_features.norm(dim=-1, keepdim=True)
            sim = image_features @ self.text_features.T
        return sim[0].cpu().detach().numpy()

In [10]:
# Train a DQN agent on LunarLander using the CLIP similarity as reward, then
# roll the trained policy out for a fixed number of steps.
PROMPT = 'Spaceship is on the landing pad'
TRAIN_TIMESTEPS = 100
ROLLOUT_STEPS = 100

env = gym.make('LunarLander-v2')
cl_env = CLIPEnv(env, model, tokenizer, PROMPT)

try:
    agent = DQN('MlpPolicy', cl_env, verbose=1)
    agent.learn(total_timesteps=TRAIN_TIMESTEPS)

    # Reset through the wrapper so the rollout uses the same interface it steps.
    obs = cl_env.reset()
    for i in range(ROLLOUT_STEPS):
        action, _states = agent.predict(obs)
        obs, rewards, dones, info = cl_env.step(action)
        # Start a new episode once the current one has finished instead of
        # stepping an environment that already reported `done`.
        if dones:
            obs = cl_env.reset()
finally:
    # Release the environment even when the cell is interrupted.
    env.close()

Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
tensor([[0.2252]])
tensor([[0.2096]])
tensor([[0.2070]])
tensor([[0.2160]])
tensor([[0.2219]])
tensor([[0.2251]])
tensor([[0.2212]])
tensor([[0.2187]])
tensor([[0.2126]])
tensor([[0.2133]])
tensor([[0.2259]])
tensor([[0.2235]])
tensor([[0.2238]])
tensor([[0.1905]])
tensor([[0.1905]])
tensor([[0.2143]])
tensor([[0.2127]])
tensor([[0.2133]])
tensor([[0.2075]])
tensor([[0.2125]])
tensor([[0.1905]])
tensor([[0.1905]])
tensor([[0.1905]])
tensor([[0.1905]])
tensor([[0.1905]])
tensor([[0.2272]])
tensor([[0.2120]])
tensor([[0.2138]])
tensor([[0.2194]])
tensor([[0.2151]])
tensor([[0.2181]])
tensor([[0.1905]])
tensor([[0.1905]])
tensor([[0.1905]])
tensor([[0.1905]])
tensor([[0.2079]])
tensor([[0.2168]])
tensor([[0.2138]])
tensor([[0.2247]])
tensor([[0.2301]])
tensor([[0.2250]])
tensor([[0.2288]])
tensor([[0.2352]])
tensor([[0.2097]])
tensor([[0.2217]])
tensor([[0.2310]])
tensor([[0.2260]])
tensor([[0.2

KeyboardInterrupt: 