In [None]:
!conda create --name ani python=3.8
conda install -c conda-forge jupyterlab
pip install open_clip_torch
pip install stable-baselines3
pip install gym[all]
pip install pyglet==1.5.27

In [1]:
import gym

from stable_baselines3 import DQN
import open_clip
from PIL import Image
import torch
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32-quickgelu', pretrained='laion400m_e32')
tokenizer = open_clip.get_tokenizer('ViT-B-32-quickgelu')

In [9]:
class CLIPEnv():
    def __init__(self, env, clip_model, tokenizer, prompt):
        self.env = env
        self.model = clip_model
        self.tokenizer = tokenizer
        self.text_features = self.model.encode_text(self.tokenizer([prompt]))
        self.text_features /= self.text_features.norm(dim=-1, keepdim=True)
        
        self.action_space = self.env.action_space
        self.observation_space = self.env.observation_space
        self.metadata = self.env.metadata 

    
    def reset(self):
        return env.reset()
    
    def step(self, action):
        next_st, rwd, done, info = self.env.step(action)
        img = env.render(mode="rgb_array")
        clip_rwd = self.get_clip_reward(img)
        print(clip_rwd)
        return next_st, clip_rwd, done, info
    
    def get_clip_reward(self, state):
        image = preprocess(Image.fromarray(np.uint8(state))).unsqueeze(0)
        with torch.no_grad(), torch.cuda.amp.autocast():
            image_features = self.model.encode_image(image)
            image_features /= image_features.norm(dim=-1, keepdim=True)
            # text_probs = (image_features @ self.text_features.T).softmax(dim=-1)
            sim = (image_features @ self.text_features.T)
        return sim.cpu().detach().numpy()[0][0]

In [10]:
env = gym.make('LunarLander-v2')
cl_env = CLIPEnv(env, model, tokenizer, 'Spaceship is on the landing pad')

agent = DQN('MlpPolicy', cl_env, verbose=1)
agent.learn(total_timesteps=100)

obs = env.reset()
for i in range(100):
    action, _states = agent.predict(obs)
    obs, rewards, dones, info = cl_env.step(action)

env.close()

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
0.26732844
0.2596045
0.26246786
0.2609768
0.26264134
0.2623297
0.26166695
0.25834024
0.2501806
0.24856597
0.25076312
0.24031901
0.24935904
0.2500715
0.2596199
0.2648692
0.260556
0.26544598
0.26074168
0.26232168
0.2613571
0.27001944
0.26764834
0.26814714
0.2623817
0.25963897
0.2595424
0.25685906
0.26148784
0.270525
0.2710784
0.26365536
0.2576868
0.26935312
0.25662065
0.26320785
0.25574166
0.26476103
0.26227045
0.2548352
0.25724578
0.26256025
0.25717688
0.26175398
0.27064753
0.25965115
0.2654138
0.26585603
0.27221745
0.2698751
0.26591492
0.27840614
0.28038034
0.27804133
0.27548277
0.2724019
0.27841023
0.27422804
0.2656194
0.2719066
0.27948564
0.2762191
0.26929155
0.26973706
0.2706406
0.2772346
0.2748541
0.28613707
0.27860147
0.2834376
0.28007388
0.27555484
0.27860448
0.27084056
0.27432492
0.2644498
0.26458853
0.24833286
0.2565048
0.26102203
0.23546955
0.24080728
0.25535405
0.2551629
0.24308133
0