In [None]:
import os
import json
import subprocess

import gym
import torch
import numpy as np
from stable_baselines3 import DQN
from PIL import Image
from tensorboardX import SummaryWriter
import scipy.stats as stats
import matplotlib.pyplot as plt

import cloob.clip as clip
from cloob.clip import _transform
from cloob.model import CLIPGeneral
import cloob.zeroshot_data as zeroshot_data

In [None]:
# Run on the first CUDA GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print("Device is ", device)

In [None]:
class CLOOBEnv:
    """Gym-style env wrapper that scores every rendered frame against a text prompt with CLOOB.

    On each ``step`` the wrapped env is rendered to an RGB array and compared with ``prompt``
    via the dot product of unit-normalised CLOOB image/text embeddings (scaled by 30).
    The caller still receives the wrapped env's own reward; the CLOOB score is only recorded
    (``cloob_rewards``), and per-episode sums of both rewards are logged to ``writer``.

    Args:
        env: wrapped environment using the old gym API (``step`` returns a 4-tuple and
            ``render`` accepts ``mode=``).
        cloob_model: model exposing ``encode_text`` and ``encode_image``.
        cloob_preprocess: callable mapping a PIL image to an image tensor.
        tokenizer: callable mapping ``prompt`` to the model's text input.
        prompt: goal description passed to ``tokenizer`` (presumably a string or list of
            strings — the embeddings are averaged over dim 0).
        writer: TensorBoard-style writer exposing ``add_scalar(tag, value, step)``.

    Attributes:
        env_rewards, cloob_rewards: one list of per-step rewards per finished episode.
        n_steps, n_episodes: number of steps taken and episodes finished.
    """

    def __init__(self, env, cloob_model, cloob_preprocess, tokenizer, prompt, writer):
        self.env = env

        self.model = cloob_model
        self.preprocess = cloob_preprocess
        self.tokenizer = tokenizer
        # Every tensor fed to the model must live on the same device as its weights (CPU or GPU).
        self.device = self._model_device(cloob_model)

        # Embed the prompt once; each rendered frame is later compared against this vector.
        with torch.no_grad():
            tokens = self.tokenizer(prompt)
            if torch.is_tensor(tokens):
                tokens = tokens.to(self.device)
            class_embeddings = self.model.encode_text(tokens)
            class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
            # Mean of the unit-normalised rows (one per tokenized prompt), re-normalised.
            class_embedding = class_embeddings.mean(dim=0)
            self.text_features = class_embedding / class_embedding.norm()

        self.action_space = self.env.action_space
        self.observation_space = self.env.observation_space
        self.metadata = self.env.metadata

        # Per-step rewards of the episode in progress.
        self.cloob_rewards_per_episode = []
        self.env_rewards_per_episode = []

        # One list of per-step rewards per finished episode.
        self.cloob_rewards = []
        self.env_rewards = []

        self.writer = writer
        self.n_steps = 0
        self.n_episodes = 0

    def __getattr__(self, name):
        # Only reached when normal lookup fails: forward to the wrapped env so code probing
        # attributes this class does not copy (e.g. ``reward_range``, ``spec``, ``seed``) works.
        # Dunders and ``env`` itself are never forwarded (avoids recursion / pickling surprises).
        if name == "env" or name.startswith("__"):
            raise AttributeError(name)
        return getattr(self.env, name)

    @staticmethod
    def _model_device(model):
        """Return the device of ``model``'s first parameter, or the CPU if it has none."""
        parameters = getattr(model, "parameters", None)
        if parameters is None:
            return torch.device("cpu")
        first = next(iter(parameters()), None)
        return first.device if first is not None else torch.device("cpu")

    def reset(self, **kwargs):
        """Reset the wrapped env and discard rewards of any unfinished episode."""
        # Without this, a reset before ``done`` would merge two episodes into one record.
        self.env_rewards_per_episode = []
        self.cloob_rewards_per_episode = []
        return self.env.reset(**kwargs)

    def close(self):
        """Close the wrapped env."""
        return self.env.close()

    def step(self, action):
        """Step the wrapped env, score the rendered frame with CLOOB and record both rewards.

        Returns the wrapped env's ``(next_state, reward, done, info)`` unchanged; the
        CLOOB score is only stored and logged.
        """
        next_st, rwd, done, info = self.env.step(action)
        img = self.env.render(mode="rgb_array")
        cloob_rwd = self.get_cloob_reward(img)

        self.cloob_rewards_per_episode.append(cloob_rwd)
        self.env_rewards_per_episode.append(rwd)

        if done:
            self.writer.add_scalar('episode_rewards/env_reward',  sum(self.env_rewards_per_episode), self.n_episodes)
            self.writer.add_scalar('episode_rewards/cloob_reward', sum(self.cloob_rewards_per_episode), self.n_episodes)

            self.env_rewards.append(self.env_rewards_per_episode)
            self.cloob_rewards.append(self.cloob_rewards_per_episode)

            self.env_rewards_per_episode = []
            self.cloob_rewards_per_episode = []

            self.n_episodes += 1

        self.n_steps += 1

        return next_st, rwd, done, info

    def get_cloob_reward(self, state):
        """Return 30 x the similarity between the frame ``state`` and the prompt embedding.

        ``state`` is an image array (converted with ``np.uint8`` and ``Image.fromarray``).
        """
        with torch.no_grad():
            image = self.preprocess(Image.fromarray(np.uint8(state)))
            image = image.unsqueeze(0).to(self.device)
            image_features = self.model.encode_image(image)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            # Scaled by 30 only so that differences between frames are more stark.
            sim = (image_features @ self.text_features) * 30
        # ``.cpu()`` is required before ``.numpy()`` when the model runs on a GPU.
        return sim.cpu().numpy()[0]

In [None]:
def _safe_pearsonr(x, y):
    """Return Pearson's r of two equal-length sequences as a float, or NaN if they have < 2 points.

    ``scipy.stats.pearsonr`` raises ``ValueError`` for inputs shorter than two elements,
    which would otherwise abort the whole report after training has already finished.
    """
    if len(x) < 2:
        return float('nan')
    return float(stats.pearsonr(x, y)[0])


def run_exp(agent, env, prompt, clip_model_name, env_name, exp_path, n_steps, notes=''):
    """Train ``agent``, save it, and report how well CLOOB rewards track environment rewards.

    Side effects:
        * creates ``exp_path`` (including missing parent directories) if needed;
        * saves the trained agent to ``{exp_path}/agent``;
        * writes a summary dict to ``{exp_path}/results.json``;
        * logs the per-episode correlation to ``env.writer`` and flushes it.

    Args:
        agent: model exposing ``learn(total_timesteps=..., progress_bar=...)`` and ``save(path)``.
        env: wrapper whose recorded rewards are analysed; must expose ``env_rewards`` and
            ``cloob_rewards`` (one list of per-step rewards per finished episode),
            ``n_episodes``, ``n_steps`` and ``writer``.
        prompt, clip_model_name, env_name, notes: recorded verbatim in results.json.
        exp_path: output directory.
        n_steps: number of training timesteps passed to ``agent.learn``.

    Correlations that cannot be computed (fewer than two episodes / steps) are stored as NaN.
    """
    os.makedirs(exp_path, exist_ok=True)

    agent.learn(total_timesteps=n_steps, progress_bar=True)
    agent.save(f"{exp_path}/agent")

    # Episode returns: total env reward and total CLOOB score of each finished episode.
    env_returns = [sum(e) for e in env.env_rewards]
    cloob_returns = [sum(e) for e in env.cloob_rewards]

    corr = _safe_pearsonr(env_returns, cloob_returns)
    m_rwd = float(np.mean(env_returns[-10:])) if env_returns else float('nan')
    results = {
        'env_name': env_name,
        'prompt': prompt,
        'clip_model_name': clip_model_name,
        'correlation': corr,
        'mean_env_rwd_over_last_10_episodes': m_rwd,
        'n_episodes': env.n_episodes,
        'n_steps': env.n_steps,
        'notes': notes,
    }
    with open(f'{exp_path}/results.json', 'w') as f:
        json.dump(results, f)

    # Correlation between env and CLOOB rewards within each episode separately.
    per_episode_corr = [_safe_pearsonr(e, c) for e, c in zip(env.env_rewards, env.cloob_rewards)]
    for i, episode_corr in enumerate(per_episode_corr):
        env.writer.add_scalar('Per episode correlation', episode_corr, i)
    env.writer.flush()

In [None]:
# Load CLOOB: read the checkpoint, build the architecture from the JSON config it names,
# then restore the trained weights.
checkpoint_path = './checkpoints/cloob_rn50_yfcc_epoch_28.pt'
configs_path = './cloob/model_configs/'

# NOTE: torch.load unpickles arbitrary objects — only load checkpoints from a trusted source.
checkpoint = torch.load(checkpoint_path, map_location=device)
model_config_file = os.path.join(configs_path, checkpoint['model_config_file'])

print('Loading model from', model_config_file)
# Explicit check instead of `assert`, which is stripped when Python runs with -O.
if not os.path.exists(model_config_file):
    raise FileNotFoundError(f"CLOOB model config not found: {model_config_file}")
with open(model_config_file, 'r') as f:
    model_info = json.load(f)
cloob_model = CLIPGeneral(**model_info)
cloob_preprocess = _transform(cloob_model.visual.input_resolution, is_train=False)

if not torch.cuda.is_available():
    cloob_model.float()
else:
    cloob_model.to(device)

# Checkpoint keys are expected to carry a "module." prefix (presumably from a DataParallel/DDP
# wrapper at training time); strip it only where present so unprefixed keys stay intact.
sd = {
    (k[len('module.'):] if k.startswith('module.') else k): v
    for k, v in checkpoint["state_dict"].items()
}
# Not a parameter of CLIPGeneral, presumably training-only — TODO confirm; pop is a no-op if absent.
sd.pop('logit_scale_hopfield', None)
cloob_model.load_state_dict(sd)
# Trailing `;` keeps the notebook from printing the full model repr.
cloob_model.eval();

In [None]:
# Idempotent: unlike `!mkdir`, re-running this cell does not error when the directory exists.
os.makedirs('experiments', exist_ok=True)

In [None]:
%load_ext tensorboard
# Relative path, the same directory the experiment cell writes to (EXP_PATH); the former
# absolute /content/... path only existed on Colab.
%tensorboard --logdir experiments/first_run

In [None]:
# Start a standalone TensorBoard server in the background. A foreground `!tensorboard ...`
# never returns, which blocks every later cell (including training) under Run All.
# --host 0.0.0.0 listens on all network interfaces; use 127.0.0.1 if only local access is needed.
# Stop it with `tensorboard_proc.terminate()`.
tensorboard_proc = subprocess.Popen(
    ['tensorboard', '--logdir', 'experiments/first_run/', '--host', '0.0.0.0', '--port', '6006']
)

In [None]:
# Experiment configuration
EXP_PATH = 'experiments/first_run/'
ENV_NAME = 'LunarLander-v2'
N_STEPS = 20000
# Recorded in results.json; derived from the checkpoint file name so the two cannot drift apart.
CLOOB_MODEL_NAME = os.path.splitext(os.path.basename(checkpoint_path))[0]

env = gym.make(ENV_NAME)
prompt = 'Spaceship is on the landing pad'
writer = SummaryWriter(EXP_PATH)

# The agent is trained on the env's own reward; CLOOBEnv only records the CLOOB score next to it.
cl_env = CLOOBEnv(env, cloob_model, cloob_preprocess, clip.tokenizer, prompt, writer)

agent = DQN('MlpPolicy', cl_env, verbose=0, learning_starts=1000, buffer_size=15000, target_update_interval=500)
run_exp(agent, cl_env, prompt, CLOOB_MODEL_NAME, ENV_NAME, EXP_PATH, N_STEPS, 'all additional info about experiment goes here')
# The writer may buffer events; flush so TensorBoard shows this run without waiting.
writer.flush()