### Imports

In [None]:
import gymnasium as gym
from gymnasium import Env
from gymnasium.spaces import Discrete, Box
import vizdoom
import numpy as np
import mss
import time
import mss.tools

from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info

import torch
import numpy as np
from PIL import Image
import torch.nn as nn

### Initializing Doom Env

In [None]:
class VizDoomGym(Env):
    """Minimal Gym-style wrapper around a ViZDoom scenario.

    Uses the legacy API shape relied on by the cells below: ``reset()`` returns
    only the observation and ``step()`` returns ``(state, reward, done, info)``
    with ``info == {"info": game_variables}``.
    """

    def __init__(self, render=False, config='ViZDoom/scenarios/health_gathering.cfg'):
        """Create and start a DoomGame from ``config``.

        Args:
            render: show the Doom window when truthy.
            config: path to the ViZDoom scenario ``.cfg`` file.
        """
        super().__init__()

        self.game = vizdoom.DoomGame()
        self.game.load_config(config)
        self.game.set_window_visible(bool(render))
        self.game.init()

        # NOTE(review): hard-coded; the real screen_buffer shape/layout is set by the
        # .cfg (screen_resolution / screen_format) and may not match this — confirm.
        self.observation_space = Box(low=0, high=255, shape=(100, 160, 1), dtype=np.uint8)
        # Three actions, each pressing exactly one button (one-hot rows in ``step``);
        # assumes the .cfg declares exactly three available buttons — TODO confirm.
        self.action_space = Discrete(3)

    def step(self, action, tics=1):
        """Apply ``action`` for ``tics`` game tics.

        Args:
            action: index 0, 1 or 2 selecting which single button is pressed.
            tics: number of tics the action is held for (default 1).

        Returns:
            ``(state, reward, done, info)``. ``reward`` is the value returned by
            ``make_action``, including on the step that ends the episode. When no
            game state is available, ``state`` is a uint8 zero array shaped like
            ``observation_space`` and ``info["info"]`` is a zero array with one
            entry per available game variable (so callers can always index it).
        """
        actions = np.identity(3)
        # Keep the reward even on the final step so a terminal reward/penalty is not lost.
        reward = self.game.make_action(actions[action], tics)

        # get_state() yields no state once the episode has ended; query it only once.
        game_state = self.game.get_state()
        if game_state:
            state = game_state.screen_buffer
            info = game_state.game_variables
        else:
            state = np.zeros(self.observation_space.shape, dtype=self.observation_space.dtype)
            info = np.zeros(self.game.get_available_game_variables_size())

        done = self.game.is_episode_finished()
        return state, reward, done, {"info": info}

    def render(self):
        """No-op: the ViZDoom window itself is shown when ``render=True`` was passed."""
        pass

    def reset(self):
        """Start a new episode and return its first screen buffer."""
        self.game.new_episode()
        return self.game.get_state().screen_buffer

    def close(self):
        """Shut down the underlying DoomGame."""
        self.game.close()

### Running a single test episode

In [None]:
# Smoke test: run one episode with a random policy and record observations/infos.
env = VizDoomGym(render=True)
tics = 1
states = []
infos = []
try:
    for episode in range(1):
        obs = env.reset()
        done = False
        total_reward = 0
        while not done:
            # Apply the sampled action (it was previously computed and then ignored).
            action = env.action_space.sample()
            obs, reward, done, info = env.step(action, tics)
            states.append(obs)
            infos.append(info)
            total_reward += reward
        print('Total Reward for episode {} is {}'.format(episode, total_reward))
finally:
    # Always shut the game down, even if the loop raises.
    env.close()

### Testing Trained Model

In [None]:
# Value Head for the VLM model
class VLMWithValueHead(nn.Module):
    """Wrap a vision-language model and attach a scalar value head.

    The head reads the final layer's hidden state at the last sequence position
    and maps it through a small MLP (hidden_size -> 256 -> 1) in bfloat16.
    """

    def __init__(self, vlm_model, device="cuda"):
        super().__init__()
        self.vlm_model = vlm_model
        width = vlm_model.config.hidden_size
        # The Sequential indices (0 and 2) appear in saved checkpoint keys
        # ("value_head.0.*", "value_head.2.*"), so this exact layout must be kept.
        head_layers = [
            nn.Linear(width, 256, dtype=torch.bfloat16),
            nn.ReLU(),
            nn.Linear(256, 1, dtype=torch.bfloat16),
        ]
        self.value_head = nn.Sequential(*head_layers).to(device)
        self.device = device

    def forward(self, inputs):
        """Run the wrapped VLM on ``inputs`` and return ``(outputs, value)``.

        ``inputs`` is unpacked as keyword arguments; hidden states are requested
        so the value head can read the last layer.
        """
        outputs = self.vlm_model(**inputs, output_hidden_states=True)
        final_layer = outputs.hidden_states[-1]
        last_position = final_layer[:, -1, :]
        return outputs, self.value_head(last_position)

In [None]:
# Load the base Qwen2.5-VL model + processor, wrap it with the value head, and
# restore the fine-tuned weights from a local checkpoint.
device = "mps"
model_path = "./vlm_ppo_model_iter_20.pt"
BASE_MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct"

model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    BASE_MODEL_ID, torch_dtype="auto", device_map=device,
)
processor = AutoProcessor.from_pretrained(BASE_MODEL_ID)

model = VLMWithValueHead(model, device=device)
# weights_only=True refuses arbitrary pickled objects, so a tampered file cannot run code.
# Assumes the checkpoint is a plain state_dict of tensors — TODO confirm how it was saved.
model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
model.eval()


In [None]:
# Action names the prompt asks for, mapped to VizDoomGym action indices.
_ACTION_IDS = {"MOVELEFT": 0, "MOVERIGHT": 1, "STRAIGHT": 2}
# Used when no action name can be found in the model output.
_FALLBACK_ACTION = 2


def _state_to_image(state):
    """Convert an environment observation into an RGB ``PIL.Image``.

    Accepts channel-first ``(1|3, H, W)``, channel-last ``(H, W, 1|3)`` and 2-D
    ``(H, W)`` arrays. Single-channel frames are replicated to 3 channels because
    ``Image.fromarray`` cannot build an image from an ``(H, W, 1)`` array.
    """
    frame = np.asarray(state)
    if frame.ndim == 3 and frame.shape[0] in (1, 3):
        frame = np.transpose(frame, (1, 2, 0))  # CHW -> HWC
    if frame.ndim == 3 and frame.shape[-1] == 1:
        frame = frame[..., 0]  # (H, W, 1) -> (H, W)
    if frame.ndim == 2:
        frame = np.stack([frame] * 3, axis=-1)
    return Image.fromarray(frame.astype(np.uint8))


def _parse_action(generated_text):
    """Map the model's free-text answer to an action index.

    The action named right after the last ``ACTION`` marker wins, so action words
    that merely appear in the REASON cannot override the chosen action. Without
    such a marker the whole text is scanned with priority
    MOVERIGHT > MOVELEFT > STRAIGHT, defaulting to STRAIGHT (2).
    """
    import re

    text = generated_text.upper()
    tagged = re.findall(r"ACTION\W*(MOVELEFT|MOVERIGHT|STRAIGHT)", text)
    if tagged:
        return _ACTION_IDS[tagged[-1]]
    for name in ("MOVERIGHT", "MOVELEFT", "STRAIGHT"):
        if name in text:
            return _ACTION_IDS[name]
    return _FALLBACK_ACTION


def vlm_action(state, info, model=model, processor=processor, device="mps"):
    """Ask the VLM for the next action given the current frame and health.

    Args:
        state: observation from ``VizDoomGym.reset``/``step`` (CHW, HWC or 2-D array).
        info: dict whose ``"info"`` entry holds the game variables, either as a numpy
            array (first element used as health) or as a scalar.
        model: value-head wrapper; only ``model.vlm_model.generate`` is used here.
        processor: Hugging Face processor matching the model.
        device: device the processor's tensors are moved to.

    Returns:
        ``(action, generated_text)``. On any exception the traceback is printed and
        ``(2, "Error occurred: ...")`` is returned so the episode can continue.
    """
    try:
        img = _state_to_image(state)

        health = str(info["info"][0].item() if isinstance(info["info"], np.ndarray) else info["info"])

        # The prompt text is kept verbatim: the fine-tuned model expects this exact wording.
        messages = [{
            "role": "user",
            "content": [
                {"type": "image", "image": img},
                {
                    "type": "text",
                    "text": f"""You are in a game environment where your life is constantly decreasing, collect health 
                    packages present in the environment to survive. You are given with the current state of the game, 
                    you need to choose and action: MOVELEFT, MOVERIGHT, STRAIGHT.

                    Write out the reason why you are choosing an action in one concise sentence and choose an action in the shown format.
                    Current health: {health}.

                    Output format: {{ REASON: (reasoning in one line), ACTION: (one of MOVELEFT, MOVERIGHT, STRAIGHT) }}"""
                }
            ]
        }]

        text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = processor(
            text=text,
            images=image_inputs,
            videos=video_inputs,
            return_tensors="pt"
        )

        for k, v in inputs.items():
            if isinstance(v, torch.Tensor):
                inputs[k] = v.to(device)

        # Greedy decoding for a deterministic choice at evaluation time.
        with torch.no_grad():
            generated_ids = model.vlm_model.generate(
                **inputs,
                max_new_tokens=128,
                do_sample=False
            )

        # Decode only the newly generated tokens, not the echoed prompt.
        input_length = inputs["input_ids"].shape[1]
        generated_text = processor.decode(
            generated_ids[0, input_length:],
            skip_special_tokens=True
        )

        return _parse_action(generated_text), generated_text

    except Exception as e:
        print(f"Error in vlm_action: {e}")
        import traceback
        traceback.print_exc()
        return 2, f"Error occurred: {str(e)}"

In [None]:
def _current_health(info):
    """Return the first entry of ``info["info"]`` as a float.

    Accepts both an array of game variables and a plain scalar (the env may
    report a scalar such as ``0`` when no game state is available).
    """
    return float(np.asarray(info["info"]).reshape(-1)[0])


# Evaluate the trained VLM policy for two episodes with a health-based shaped reward.
env = VizDoomGym(render=True)
replay_buffer = []
tics = 4

try:
    for episode in range(2):
        obs = env.reset()
        done = False
        total_reward = 0
        info = {'info': np.array([100.0])}
        health_info = _current_health(info)
        while not done:
            action, reasoning = vlm_action(obs, info)
            obs, reward, done, info = env.step(action, tics)
            current_health_info = _current_health(info)
            if health_info >= current_health_info:
                # Health did not go up: subtract one point per tic elapsed.
                reward = reward - tics
            else:
                # Health went up (presumably a pack was collected): fixed bonus.
                reward = 10 * tics
            health_info = current_health_info
            total_reward += reward

            print(info, reasoning, reward)
            replay_buffer.append([action, reward, current_health_info, reasoning])
        print('Total Reward for episode {} is {}'.format(episode, total_reward))
finally:
    # Always shut the game down, even if an episode raises.
    env.close()