you can save a version of the notebook, then copy and edit the version.

# Useful Library:

[timm](https://github.com/huggingface/pytorch-image-models): pytorch image models

you'd use einops to flatten the tensor.
since you would doing operation separately.

to measure the overall loss, you would consider merging two types of tensor into one.

# Code

## import some library and setup data filepaths

In [2]:
# try to train a model.
# first let's get the data.

# if you want to use the template generated by gpt4, you have to use kaggle.

##############################
## UPLOAD TO KAGGLE DATASET ##
##############################
import os

basePath = "/kaggle/input/agi-computer-control-test-dataset"
videoPath = os.path.join(basePath,"93755268.mp4")
# videoPath = "ffmpeg_output_test_hwaccel.mp4"
metadata_file_path = os.path.join(basePath,"screenshot_and_actions.json")

# keep it simple, just want to scratch the surface.
# we can run this on cpu.
# from PIL import Image
import numpy as np
import cv2
import ast
from pydantic import BaseModel, validator
from typing import Union
try:
    from typing import Literal
except:
    from typing_extensions import Literal # this is a failsafe.

## define some data models

In [3]:
# import pynput
# no such dependency when training.

class HIDActionBase:
    mouse_resolution: int = 1000
    keyboard_action_types = [
        "key_press",
        "key_release",
    ]
    mouse_action_types = [
        "mouse_move",
        "mouse_click",
        "mouse_scroll",
    ]
    action_types = [
        *keyboard_action_types,
        *mouse_action_types,
        # None,  # end of action
        # there is no such thing here. do it externally.
    ]
    mouse_buttons = [
        "Button.left",
        "Button.middle",
        "Button.right",
    ]
    keys = [
        """','""",
        """'.'""",
        """'/'""",
        """';'""",
        """\"'\"""",
        """'['""",
        """']'""",
        """'\\'""",
        """'='""",
        """'-'""",
        """'0'""",
        """'9'""",
        """'8'""",
        """'7'""",
        """'6'""",
        """'5'""",
        """'4'""",
        """'3'""",
        """'2'""",
        """'1'""",
        """'`'""",
        """'a'""",
        """'b'""",
        """'c'""",
        """'d'""",
        """'e'""",
        """'f'""",
        """'g'""",
        """'h'""",
        """'i'""",
        """'j'""",
        """'k'""",
        """'l'""",
        """'m'""",
        """'n'""",
        """'o'""",
        """'p'""",
        """'q'""",
        """'r'""",
        """'s'""",
        """'t'""",
        """'u'""",
        """'v'""",
        """'w'""",
        """'x'""",
        """'y'""",
        """'z'""",
        "Key.alt",
        "Key.alt",
        "Key.alt_r",
        "Key.alt_r",
        "Key.backspace",
        "Key.caps_lock",
        "Key.cmd",
        "Key.cmd",
        "Key.cmd_r",
        "Key.ctrl",
        "Key.ctrl",
        "Key.ctrl_r",
        "Key.delete",
        "Key.down",
        "Key.end",
        "Key.enter",
        "Key.esc",
        "Key.f1",
        "Key.f2",
        "Key.f3",
        "Key.f4",
        "Key.f5",
        "Key.f6",
        "Key.f7",
        "Key.f8",
        "Key.f9",
        "Key.f10",
        "Key.f11",
        "Key.f12",
        "Key.f13",
        "Key.f14",
        "Key.f15",
        "Key.f16",
        "Key.f17",
        "Key.f18",
        "Key.f19",
        "Key.f20",
        "Key.home",
        "Key.left",
        "Key.page_down",
        "Key.page_up",
        "Key.right",
        "Key.shift",
        "Key.shift",
        "Key.shift_r",
        "Key.space",
        "Key.tab",
        "Key.up",
        "Key.media_play_pause",
        "Key.media_volume_mute",
        "Key.media_volume_down",
        "Key.media_volume_up",
        "Key.media_previous",
        "Key.media_next",
    ]

    @staticmethod
    def unshift_keycode(keycode: str) -> Union[str, None]:
        unshift_keycodes = {
            "!": "1",
            "@": "2",
            "#": "3",
            "$": "4",
            "%": "5",
            "^": "6",
            "&": "7",
            "*": "8",
            "(": "9",
            ")": "0",
            "_": "-",
            "+": "=",
            "{": "[",
            "}": "]",
            "|": "\\",
            ":": ";",
            '"': "'",
            "<": ",",
            ">": ".",
            "?": "/",
            "~": "`",
        }
        ctrl_keycodes = {
            "\x01": "a",
            "\x02": "b",
            "\x03": "c",
            "\x04": "d",
            "\x05": "e",
            "\x06": "f",
            "\x07": "g",
            "\x08": "h",
            "\t": "i",
            "\n": "j",
            "\x0b": "k",
            "\x0c": "l",
            "\r": "m",
            "\x0e": "n",
            "\x0f": "o",
            "\x10": "p",
            "\x11": "q",
            "\x12": "r",
            "\x13": "s",
            "\x14": "t",
            "\x15": "u",
            "\x16": "v",
            "\x17": "w",
            "\x18": "x",
            "\x19": "y",
            "\x1a": "z",
            "<219>": "[",
            "<221>": "]",
            "<189>": "-",
            "<187>": "=",
            "<192>": "`",
            "<48>": "0",
            "<49>": "1",
            "<50>": "2",
            "<51>": "3",
            "<52>": "4",
            "<53>": "5",
            "<54>": "6",
            "<55>": "7",
            "<56>": "8",
            "<57>": "9",
            "<220>": "\\",
            "<186>": ";",
            "<222>": "'",
            "<188>": ",",
            "<190>": ".",
            "<191>": "/",
        }
        keycode = unshift_keycodes.get(keycode, ctrl_keycodes.get(keycode, keycode))
        # still, this is something out of concern.
        if keycode.startswith("<") and keycode.endswith(">"):
            print("Discarding unconvertable keycode: %s" % keycode)
            # keycode = pynput.keyboard.KeyCode(int(keycode[1:-1]))
            return
        return keycode

    @staticmethod
    def uncover_keycode(keycode: str) -> Union[str, None]:
        if not keycode.startswith("Key."):
            keycode_converted = HIDActionBase.unshift_keycode(
                keycode
                if keycode.startswith("<") and keycode.endswith(">")
                else ast.literal_eval(keycode)
            )
            return keycode_converted
            # this could be None.
            # when this is None, simply skip this code. do not end the conversion. skip it.
        else:
            return keycode


class HIDAction(BaseModel, HIDActionBase):
    # static method: from_action
    # static method: from_ndarray
    # instance method: to_ndarray
    # instance method: to_action
    max_x: int
    max_y: int
    action_type: Union[
        Literal["key_press"],  # ["key_press", "'w'"]
        Literal["key_release"],  # ["key_release", "'r'"]
        Literal[
            "mouse_move"
        ],  # ["mouse_move", [176.7734375, 580.40625]], "timeStamp": 1680247557.125498}
        Literal[
            "mouse_click"
        ],  # ["mouse_click", [176.7734375, 580.40625, "Button.left", true]]
        Literal["mouse_scroll"],  # ["mouse_scroll", [938.76171875, 318.75, 0, 0]]
#         None,  # end_of_action
    ] # you need to specify this.
    key: Union[
        Literal["""','"""],
        Literal["""'.'"""],
        Literal["""'/'"""],
        Literal["""';'"""],
        Literal["""\"'\""""],
        Literal["""'['"""],
        Literal["""']'"""],
        Literal["""'\\'"""],
        Literal["""'='"""],
        Literal["""'-'"""],
        Literal["""'0'"""],
        Literal["""'9'"""],
        Literal["""'8'"""],
        Literal["""'7'"""],
        Literal["""'6'"""],
        Literal["""'5'"""],
        Literal["""'4'"""],
        Literal["""'3'"""],
        Literal["""'2'"""],
        Literal["""'1'"""],
        Literal["""'`'"""],
        Literal["""'a'"""],
        Literal["""'b'"""],
        Literal["""'c'"""],
        Literal["""'d'"""],
        Literal["""'e'"""],
        Literal["""'f'"""],
        Literal["""'g'"""],
        Literal["""'h'"""],
        Literal["""'i'"""],
        Literal["""'j'"""],
        Literal["""'k'"""],
        Literal["""'l'"""],
        Literal["""'m'"""],
        Literal["""'n'"""],
        Literal["""'o'"""],
        Literal["""'p'"""],
        Literal["""'q'"""],
        Literal["""'r'"""],
        Literal["""'s'"""],
        Literal["""'t'"""],
        Literal["""'u'"""],
        Literal["""'v'"""],
        Literal["""'w'"""],
        Literal["""'x'"""],
        Literal["""'y'"""],
        Literal["""'z'"""],
        Literal["Key.alt"],
        Literal["Key.alt"],
        Literal["Key.alt_r"],
        Literal["Key.alt_r"],
        Literal["Key.backspace"],
        Literal["Key.caps_lock"],
        Literal["Key.cmd"],
        Literal["Key.cmd"],
        Literal["Key.cmd_r"],
        Literal["Key.ctrl"],
        Literal["Key.ctrl"],
        Literal["Key.ctrl_r"],
        Literal["Key.delete"],
        Literal["Key.down"],
        Literal["Key.end"],
        Literal["Key.enter"],
        Literal["Key.esc"],
        Literal["Key.f1"],
        Literal["Key.f2"],
        Literal["Key.f3"],
        Literal["Key.f4"],
        Literal["Key.f5"],
        Literal["Key.f6"],
        Literal["Key.f7"],
        Literal["Key.f8"],
        Literal["Key.f9"],
        Literal["Key.f10"],
        Literal["Key.f11"],
        Literal["Key.f12"],
        Literal["Key.f13"],
        Literal["Key.f14"],
        Literal["Key.f15"],
        Literal["Key.f16"],
        Literal["Key.f17"],
        Literal["Key.f18"],
        Literal["Key.f19"],
        Literal["Key.f20"],
        Literal["Key.home"],
        Literal["Key.left"],
        Literal["Key.page_down"],
        Literal["Key.page_up"],
        Literal["Key.right"],
        Literal["Key.shift"],
        Literal["Key.shift"],
        Literal["Key.shift_r"],
        Literal["Key.space"],
        Literal["Key.tab"],
        Literal["Key.up"],
        Literal["Key.media_play_pause"],
        Literal["Key.media_volume_mute"],
        Literal["Key.media_volume_down"],
        Literal["Key.media_volume_up"],
        Literal["Key.media_previous"],
        Literal["Key.media_next"],
        None,
    ] = None

    mouse_button: Union[
        Literal["Button.left"], Literal["Button.middle"], Literal["Button.right"], None
    ] = None
    mouse_pressed: Union[bool, None] = None
    x: Union[float, None] = None
    y: Union[float, None] = None
    dx: Union[float, None] = None
    dy: Union[float, None] = None

    @validator("max_x", "max_y")
    def greater_than_zero(cls, v):
        assert type(v) == int
        assert v > 0
        return v

    @validator("action_type")
    def action_type_within_action_types(cls, v):
        if v:
            assert v in HIDActionBase.action_types
        return v

    @validator("key")
    def key_within_keys(cls, v):
        if v:
            assert v in HIDActionBase.keys
        return v

    @validator("mouse_button")
    def mouse_button_within_mouse_buttons(cls, v):
        if v:
            assert v in HIDActionBase.mouse_buttons
        return v

    @validator("mouse_pressed")
    def mouse_pressed_type_check(cls, v):
        if v:
            assert type(v) == bool
        return v

    @staticmethod
    def from_action_json(action_json: list, max_x: int, max_y: int):
        action_type = action_json[0]
        action_args = action_json[1]

        construct_args = dict(max_x=max_x, max_y=max_y, action_type=action_type)

        if action_type == "key_press":
            assert action_args in HIDActionBase.keys

            construct_args.update(dict(key=action_args))
        elif action_type == "key_release":
            assert action_args in HIDActionBase.keys

            construct_args.update(dict(key=action_args))
        elif action_type == "mouse_move":
            assert action_args[0] >= 0 and action_args[0] <= max_x
            assert action_args[1] >= 0 and action_args[1] <= max_y

            construct_args.update(dict(x=action_args[0], y=action_args[1]))
        elif action_type == "mouse_click":
            assert action_args[0] >= 0 and action_args[0] <= max_x
            assert action_args[1] >= 0 and action_args[1] <= max_y
            assert action_args[2] in HIDActionBase.mouse_buttons
            assert type(action_args[3]) == bool

            construct_args.update(
                dict(
                    x=action_args[0],
                    y=action_args[1],
                    mouse_button=action_args[2],
                    mouse_pressed=action_args[3],
                )
            )
        elif action_type == "mouse_scroll":
            assert action_args[0] >= 0 and action_args[0] <= max_x
            assert action_args[1] >= 0 and action_args[1] <= max_y
            assert action_args[2] >= -max_x and action_args[2] <= max_x
            assert action_args[3] >= -max_y and action_args[3] <= max_y

            construct_args.update(
                dict(
                    x=action_args[0],
                    y=action_args[1],
                    dx=action_args[2],
                    dy=action_args[3],
                )
            )
        else:
            raise Exception(
                "Unknown action type: %s\naction args: %s" % (action_type, action_args)
            )

        mHIDAction = HIDAction(**construct_args)
        return mHIDAction

    @staticmethod
    def from_ndarray(ndarray: np.ndarray, max_x: int, max_y: int):
        assert ndarray.shape == (
            len(HIDActionBase.action_types)
            + len(HIDActionBase.keys)
            + len(HIDActionBase.mouse_buttons)
            + 1  # mouse pressed
            + 4 * HIDActionBase.mouse_resolution,
            1,
        )
        cursor = 0

        action_type_ndarray = ndarray[cursor : cursor + len(HIDActionBase.action_types)]
        cursor += len(HIDActionBase.action_types)
        action_type_index = np.argmax(action_type_ndarray)
        action_type = HIDActionBase.action_types[action_type_index]
        del action_type_ndarray
        del action_type_index

        construct_args = dict(max_x=max_x, max_y=max_y, action_type=action_type)

        if action_type:
            key_ndarray = ndarray[cursor : cursor + len(HIDActionBase.keys)]
            cursor += len(HIDActionBase.keys)
            key_index = np.argmax(key_ndarray)
            key = HIDActionBase.keys[key_index]
            del key_ndarray
            del key_index

            mouse_button_ndarray = ndarray[
                cursor : cursor + len(HIDActionBase.mouse_buttons)
            ]
            cursor += len(HIDActionBase.mouse_buttons)
            mouse_button_index = np.argmax(mouse_button_ndarray)
            mouse_button = HIDActionBase.mouse_buttons[mouse_button_index]
            del mouse_button_ndarray
            del mouse_button_index

            mouse_pressed_ndarray = ndarray[cursor : cursor + 1]
            cursor += 1
            mouse_pressed = bool(mouse_pressed_ndarray[0][0])
            del mouse_pressed_ndarray

            x_ndarray = ndarray[cursor : cursor + HIDActionBase.mouse_resolution]
            cursor += HIDActionBase.mouse_resolution
            x_index = np.argmax(x_ndarray)
            x = (x_index / HIDActionBase.mouse_resolution) * max_x
            del x_ndarray
            del x_index

            y_ndarray = ndarray[cursor : cursor + HIDActionBase.mouse_resolution]
            cursor += HIDActionBase.mouse_resolution
            y_index = np.argmax(y_ndarray)
            y = (y_index / HIDActionBase.mouse_resolution) * max_y
            del y_ndarray
            del y_index

            dx_ndarray = ndarray[cursor : cursor + HIDActionBase.mouse_resolution]
            cursor += HIDActionBase.mouse_resolution
            dx_index = np.argmax(dx_ndarray)
            dx = (dx_index / HIDActionBase.mouse_resolution) * 2 * max_x - max_x
            del dx_ndarray
            del dx_index

            dy_ndarray = ndarray[cursor : cursor + HIDActionBase.mouse_resolution]
            cursor += HIDActionBase.mouse_resolution
            dy_index = np.argmax(dy_ndarray)
            dy = (dy_index / HIDActionBase.mouse_resolution) * 2 * max_y - max_y
            del dy_ndarray
            del dy_index

            if action_type == "key_press":
                construct_args.update(dict(key=key))
            elif action_type == "key_release":
                construct_args.update(dict(key=key))
            elif action_type == "mouse_move":
                construct_args.update(dict(x=x, y=y))
            elif action_type == "mouse_click":
                construct_args.update(
                    dict(
                        x=x, y=y, mouse_button=mouse_button, mouse_pressed=mouse_pressed
                    )
                )
            elif action_type == "mouse_scroll":
                construct_args.update(dict(x=x, y=y, dx=dx, dy=dy))
        else:
            pass

        del cursor

        mHIDAction = HIDAction(**construct_args)
        return mHIDAction

    def round_within(self, number: Union[int, float], number_name: str) -> int:
        result = round(number)
        if result > self.mouse_resolution:
            print(f"Warning: {number_name} overflow")
            print(f"Value {result} greater than {self.mouse_resolution}")
            return self.mouse_resolution
        elif result < 0:
            print(f"Warning: {number_name} overflow")
            print(f"Value {result} smaller than 0")
            return 0
        return result

    def to_ndarray(
        self,
    ) -> np.ndarray:
        action_type_ndarray = np.zeros((len(self.action_types), 1))
        action_type_ndarray[self.action_types.index(self.action_type)] = 1

        key_ndarray = np.zeros((len(self.keys), 1))
        if self.key:
            key_ndarray[self.keys.index(self.key)] = 1

        mouse_button_ndarray = np.zeros((len(self.mouse_buttons), 1))
        if self.mouse_button:
            mouse_button_ndarray[self.mouse_buttons.index(self.mouse_button)] = 1

        mouse_pressed_array = np.zeros((1, 1))
        if self.mouse_pressed:
            mouse_pressed_array[0] = 1

        x_ndarray = np.zeros((self.mouse_resolution, 1))
        if self.x:
            x_ndarray[
                self.round_within(self.mouse_resolution * self.x / self.max_x, "X")
            ] = 1

        y_ndarray = np.zeros((self.mouse_resolution, 1))
        if self.y:
            y_ndarray[
                self.round_within(self.mouse_resolution * self.y / self.max_y, "Y")
            ] = 1

        dx_ndarray = np.zeros((self.mouse_resolution, 1))
        if self.dx:
            dx_ndarray[
                self.round_within(
                    self.mouse_resolution * (self.dx + self.max_x) / (2 * self.max_x),
                    "DX",
                )
            ] = 1

        dy_ndarray = np.zeros((self.mouse_resolution, 1))
        if self.dy:
            dy_ndarray[
                self.round_within(
                    self.mouse_resolution * (self.dy + self.max_y) / (2 * self.max_y),
                    "DY",
                )
            ] = 1

        ndarray = np.concatenate(
            [
                action_type_ndarray,
                key_ndarray,
                mouse_button_ndarray,
                mouse_pressed_array,
                x_ndarray,
                y_ndarray,
                dx_ndarray,
                dy_ndarray,
            ]
        )
        return ndarray

    def to_action_json(
        self,
    ) -> Union[list, None]:
        action_type = self.action_type
        if action_type:
            if action_type == "key_press":
                assert self.key in self.keys

                action_args = self.key
            elif action_type == "key_release":
                assert self.key in self.keys

                action_args = self.key
            elif action_type == "mouse_move":
                assert self.x >= 0 and self.x <= self.max_x
                assert self.y >= 0 and self.y <= self.max_y

                action_args = [self.x, self.y]
            elif action_type == "mouse_click":
                assert self.x >= 0 and self.x <= self.max_x
                assert self.y >= 0 and self.y <= self.max_y
                assert self.mouse_button in self.mouse_buttons
                assert type(self.mouse_pressed) == bool

                action_args = [self.x, self.y, self.mouse_button, self.mouse_pressed]
            elif action_type == "mouse_scroll":
                assert self.x >= 0 and self.x <= self.max_x
                assert self.y >= 0 and self.y <= self.max_y
                assert self.dx >= -self.max_x and self.dx <= self.max_x
                assert self.dy >= -self.max_y and self.dy <= self.max_y

                action_args = [self.x, self.y, self.dx, self.dy]
            else:
                raise Exception("Unknown action_type: %s" % action_type)
            action_json = [action_type, action_args]
        else:
            action_json = None
        return action_json


## read data from path

In [4]:

import json

with open(metadata_file_path, "r") as f:
    data = json.loads(f.read())

# print(data.keys())
# ['screenshot_and_actions', 'perspective_size', 'timestep']

perspective_width, perspective_height = data["perspective_size"]
timestep = data["timestep"]

print("PERSPECTIVE:", perspective_width, perspective_height)
print("TIMESTEP:", timestep)

import re

class VideoCaptureContextManager:
    def __init__(self, videoPath):
        self.videoPath = videoPath
        
    def __enter__(self):
        print("Entering the context...")
        self.cap = cv2.VideoCapture(self.videoPath)
        return self.cap

    def __exit__(self, exc_type, exc_value, exc_tb):
        try:
            self.cap.release()
        finally:
            print("Leaving the context...")
        #  print(exc_type, exc_value, exc_tb, sep="\n")


PERSPECTIVE: 1280 800
TIMESTEP: 0.03


## test and define the model

In [5]:
import torch
import torchvision

In [6]:
vit_model = torchvision.models.vit_b_16(pretrained=True)

  f"The parameter '{pretrained_param}' is deprecated since 0.13 and may be removed in the future, "
Downloading: "https://download.pytorch.org/models/vit_b_16-c867db91.pth" to /root/.cache/torch/hub/checkpoints/vit_b_16-c867db91.pth


  0%|          | 0.00/330M [00:00<?, ?B/s]

In [None]:

import numpy as np

class ConsciousBase:
    data_types = ['image', 'HIDAction']
    special_tokens = ['image_newline', 'image_end', 'action_end', None]
    data_type_mappings = {t:np.array([i]) for i, t in enumerate(data_types)}
    special_token_mappings = {t:np.array(list(divmod(i,2))) for i, t in enumerate(special_tokens)}
    vector_size = 1+2+1000+4110 # visual things are pre-encoded. no raw image here!
#     vector_size = 1+2+224*224*3+4110

# can it be consciousnessless?

class ConsciousBlock(BaseModel):
    data_type:Union[Literal["image"],
                   Literal["HIDAction"]] # 1 bit, required
    special_token:Union[Literal["image_newline"],
                       Literal["image_end"],
                       Literal["action_end"],
                       None]=None # 2 bits
    image_data:Union[None, np.ndarray] = None # what is the shape of this image data? assume to be [224,224,3]
    action_data:Union[None, np.ndarray] = None # assume to be: [1, 4110]
    # [1,1000] -> [3,1000,1000] -> [3,224,224]
    #    einsum.repeat       conv2d
    
    # so, maybe you still need some ViT decode layer.
    
    @staticmethod 
    def from_json(data:Mapping):
        mConsciousBlock = ConsciousBlock(**data)
        return mConsciousBlock

    def to_json(self) -> Mapping:
        mJson = self.dict()
        return mJson
    
    @staticmethod
    def from_ndarray(ndarray:np.ndarray):
        # check its shape.
        return mConsciousBlock
    
    def to_ndarray(self) -> np.ndarray:
        return mNDArray
    
class ConsciousFlow(BaseModel, ConsciousBase):
    consciousBlocks:List[ConsciousBlock]
        
    @staticmethod
    def from_json(data:List[Mapping]):
        mList = [ConsciousBlock.from_json(j) for j in data]
        mConsciousFlow = ConsciousFlow(mList)
        return mConsciousFlow

    def to_json(self) -> List[Mapping]:
        mJson = [c.to_json() for c in self.consciousBlocks]
        return mJson
    
    @staticmethod
    def from_ndarray(ndarray:np.ndarray):
        consciousBlockCount, vector_length = ndarray.shape
        assert vector_length ==  ConsciousBase.vector_length
        mConsciousBlocks = []
        for i in range(consciousBlockCount):
            arr = ndarray[i,:]
            mConsciousBlock = ConsciousBlock.from_ndarray(arr)
            mConsciousBlocks.append(mConsciousBlock)
        mConsciousFlow = ConsciousFlow(mConsciousBlocks)
        return mConsciousFlow
    
    def to_ndarray(self) -> np.ndarray:
        mNDArray = np.array([c.to_ndarray() for c in self.consciousBlocks])
        return mNDArray

In [None]:
# notice: when in online mode only image will be backpropagated.
# like using some upside down mirror.

class CustomModel(torch.nn.Module):
    def __init__(self, vit_model, hidden_size_vit, output_size, vit_times = 4, vit_block_size=228):
#     def __init__(self, rwkv_model, vit_model, tokenizer, hidden_size_rwkv, hidden_size_vit, output_size, vit_times = 4, vit_block_size=228):
        super(CustomModel, self).__init__()
#         self.rwkv_model = rwkv_model # processing language, generate actions.
        self.vit_model = vit_model
        self.vit_block_size = vit_block_size # this is default.
        self.vit_times = vit_times
        
        # seq2seq alike.
#         self.hidden_size = hidden_size_rwkv+hidden_size_vit
        self.hidden_size = hidden_size_vit
    
        self.HIDEncoder = torch.nn.Linear()
        self.HIDDecoder = torch.nn.Linear() # use torch.where or something 
        # sparse matrix?
        
        self.ViTDecoder = torch.nn.Conv2d()
    
        self.rnn = torch.nn.LSTM(input_size=1000, hidden_size = hidden_size_vit, batch_first=True)

        # use tensor.
    def forward(self, conscious_stream:torch.Tensor, last_output, rnn_hidden_state=None):
        # conscious_stream: [batch_size, data_type+special_tokens+image_bits+action_bits]
        # you need another pydantic model for it.
#     def forward(self, instruction:str, screenshot:np.ndarray, last_output = None, rnn_hidden_state=None):
#         inputs = self.tokenizer.encode(instruction)
#         lm_output, lm_state = self.rwkv_model.forward(inputs.ids)

# screenshots shall be of (16, 3, 224, 224)
        
        desired_size = self.vit_block_size*self.vit_times
        
#         screenshot_new = resizeImage(screenshot, desired_size)
        # this should be resized ahead.
        vit_output_resized = self.vit_model(screenshots)
#         _, (h0, c0) = self.rnn_encoder_rwkv(lm_state_resized)
        _, (h1, c1) = self.rnn_encoder_vit(vit_output_resized)
        
#         h = torch.cat((h0,h1), dim=1)
#         c = torch.cat((c0,c1), dim=1)
        decoder_output, rnn_hidden_state = self.rnn_decoder(last_output, rnn_hidden_state)
        
#         rnn_output, rnn_hidden_state = self.rnn_decoder(combined_features.unsqueeze(1), (h,c))
        
        output = self.fc(decoder_output)

        return output, rnn_hidden_state

## training

In [8]:
# train the model.
rnn_hidden_state = None # because this is needed for every session per video.

hidden_size_rwkv = hidden_size_vit = 1024
output_size = 4110

model = CustomModel(vit_model, hidden_size_vit, output_size)


In [9]:
from torch.optim import Adam
# from torch.nn import BCELoss
# from torch.nn import MSELoss
from torch.nn import CrossEntropyLoss
lr = 0.00001
optimizer= Adam(model.parameters(), lr=lr) 

# loss_fn = BCELoss(reduction='none')
# loss_fn = MSELoss()
loss_fn = CrossEntropyLoss(reduction='sum')

In [10]:
import cv2

def resizeImage(im, desired_size):
    # im = cv2.imread(im_pth)
    old_size = im.shape[:2] # old_size is in (height, width) format

    ratio = float(desired_size)/max(old_size)
    new_size = tuple([int(x*ratio) for x in old_size])

    # new_size should be in (width, height) format

    im = cv2.resize(im, (new_size[1], new_size[0]))

    delta_w = desired_size - new_size[1]
    delta_h = desired_size - new_size[0]
    top, bottom = delta_h//2, delta_h-(delta_h//2)
    left, right = delta_w//2, delta_w-(delta_w//2)

    color = [0, 0, 0]
    new_im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT,
        value=color)
    return new_im

In [11]:
# just for test.
# shall we do some backward?

# consider merging multiple ViT converted images to batch them?
# from PIL import Image

desired_size = 224*4

# for every image it adds 16 conscious block in total.
# it could be a huge array

context_length = 20
batch_size = 10

# technically we would "shuffle" the dataset.
# but now, we just read it in sequence.

with VideoCaptureContextManager(videoPath) as cap:
    last_index = -1
    for screenshot_and_actions in data["screenshot_and_actions"]:
        screenshot = screenshot_and_actions["screenshot"]
        actions = screenshot_and_actions["actions"]
        print()
        print("SCREENSHOT:", screenshot)
        image_path = screenshot["imagePath"]
        image_size = screenshot["imageSize"]
        index = re.findall(r"\d+", image_path)[0]
        index = int(index)
        print("IMAGE INDEX:", index)
        
        encoded_actions = []

        for action in actions:
            action_type, action_args = action["HIDEvent"]
            if action_type in HIDActionBase.keyboard_action_types:
                action_args = HIDActionBase.uncover_keycode(action_args)
                if action_args is None:
                    print("Skipping")
                    continue
            mHIDAction = HIDAction.from_action_json(
                [action_type, action_args],
                max_x=perspective_width,
                max_y=perspective_height,
            )  # related to mouse coordinates.
            mHIDActionNDArray = mHIDAction.to_ndarray()
            print(mHIDActionNDArray.shape)


        if index != last_index:
            ret, image = cap.read()
#             print(image.shape) # 400, 640, 3
# shall be the lowest quality.
            # loop through all actions in this area.
            image_resized = resizeImage(image, desired_size)
            image_reshaped = np.rollaxis(image_resized, 2, 0) # (3, 896, 896)
            image_sliced = np.array([
                image_reshaped[:,\
                                                   x*224:(x+1)*224,\
                                                   y*224:(y+1)*224
                                                  ] for x in range(4) for y in range(4)
            ])
            
            # IMAGE RESHAPED: (896, 3, 896)?
            # IMAGE RESHAPED: (896, 896, 3)
#             print('IMAGE RESHAPED:', image_reshaped.shape)
            print('IMAGE SLICED:', image_sliced.shape)
#     (16, 3, 224, 224)
# hell?
            last_output = torch.zeros(1, output_size)
            for EA in encoded_actions:
#                 print('ACTION SHAPE?', EA.to_ndarray().shape) # (4111, 1)
                if last_output is not None:
                    last_output = last_output.detach()
                if rnn_hidden_state is not None:
                    rnn_hidden_state = (rnn_hidden_state[0].detach(), rnn_hidden_state[1].detach())
                last_output, rnn_hidden_state = model.forward(torch.tensor(image_sliced).float(), last_output, rnn_hidden_state)
#                 output, rnn_hidden_state = model.forward(instruction, pil_image, rnn_hidden_state)
                print("OUTPUT:", last_output.shape)
                print("HIDDEN_STATE:", rnn_hidden_state[0].shape, rnn_hidden_state[1].shape) # used for next action, if possible.
                target = torch.tensor(EA.to_ndarray().reshape(1,-1)).float()
                print("TARGET SHAPE:", target.shape)
                print("OUTPUT SHAPE:", last_output.shape)
                # what the heck?
#                 TARGET SHAPE: torch.Size([1, 4111])
#                 OUTPUT SHAPE: torch.Size([1, 4111])
                print(target)
                print(last_output)
#                 with torch.enable_grad():
                loss = loss_fn(last_output, target)
                print("LOSS?")
                print(loss)
            # this loss is incorrect. shall use some argmax stuff.
            # to ensure that this thing is the thing that we want.
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
            del image
            del image_reshaped
            del image_resized
            last_index = index
        else:
            pass
        # no need to jump to index. it is monotonic.
        # image = read_image_from_video_at_given_index_as_ndarray(index, cap)
        # SCREENSHOT: {'timeStamp': 1680247598.852561, 'imagePath': 'screenshots/498.raw', 'imageSize': [2560, 1600]}
        # print(image.shape)  # (1600, 2560, 3)
        # width, height, channel
        # breakpoint()
        print("ACTIONS:", actions)
        # keyboard action types: ['key_press', 'key_release']
        # mouse action types: ['mouse_move', 'mouse_click', 'mouse_scroll']
        # cv2.imshow("IMG", image)
        # cv2.waitKey(0)        
        # del image
        
        # not deleting the image. keep it!

        # now the data reader is complete. focus on the network design.
        # shall we?

Entering the context...

SCREENSHOT: {'timeStamp': 1680247551.452273, 'imagePath': 'screenshots/0.raw', 'imageSize': [2560, 1600]}
IMAGE INDEX: 0
IMAGE SLICED: (16, 3, 224, 224)
ACTIONS: []

SCREENSHOT: {'timeStamp': 1680247551.510298, 'imagePath': 'screenshots/1.raw', 'imageSize': [2560, 1600]}
IMAGE INDEX: 1
IMAGE SLICED: (16, 3, 224, 224)
ACTIONS: []

SCREENSHOT: {'timeStamp': 1680247551.510298, 'imagePath': 'screenshots/1.raw', 'imageSize': [2560, 1600]}
IMAGE INDEX: 1
ACTIONS: []

SCREENSHOT: {'timeStamp': 1680247551.562409, 'imagePath': 'screenshots/2.raw', 'imageSize': [2560, 1600]}
IMAGE INDEX: 2
IMAGE SLICED: (16, 3, 224, 224)
ACTIONS: []

SCREENSHOT: {'timeStamp': 1680247551.562409, 'imagePath': 'screenshots/2.raw', 'imageSize': [2560, 1600]}
IMAGE INDEX: 2
ACTIONS: []

SCREENSHOT: {'timeStamp': 1680247551.621871, 'imagePath': 'screenshots/3.raw', 'imageSize': [2560, 1600]}
IMAGE INDEX: 3
IMAGE SLICED: (16, 3, 224, 224)
ACTIONS: []

SCREENSHOT: {'timeStamp': 1680247551.621871