In [None]:
!pip install torch
!pip install git+https://github.com/huggingface/transformers -qqq
!pip install sentencepiece -qqq
!pip install bitsandbytes -qqq
!pip install accelerate -qqq
!pip uninstall nvidia_cublas_cu11

  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m302.0/302.0 kB[0m [31m6.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.8/3.8 MB[0m [31m24.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m51.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m295.0/295.0 kB[0m [31m33.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for transformers (pyproject.toml) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m11.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m92.6/92.6 MB[0m [31m9.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━

In [None]:
import torch
from torch import nn
import json

from dataclasses import dataclass, field
from torch.utils.data import Dataset

In [None]:
#!g1.1
from PIL import Image
import os

class JsonDataset(Dataset):
    """Dataset of ``(prompt, image, plan)`` samples described by a JSON file.

    The JSON file must contain an iterable of objects with the keys
    ``'goal_eng'`` (the prompt), ``'image'`` (a path joined onto ``img_dir``)
    and ``'plan'``. Only the paths are collected at construction time; each
    image is read from disk when the sample is accessed.

    Args:
        filename: Path to the JSON annotation file.
        img_dir: Directory that every ``'image'`` entry is joined onto.
        transform: Optional callable applied to the loaded image in
            ``__getitem__``. ``None`` (default) returns the image unchanged.
        verbose: Stored on the instance; not otherwise used by this class.
    """

    def __init__(self, filename, img_dir, transform=None, verbose=True):
        super().__init__()

        self.verbose = verbose
        self.transform = transform
        self.img_dir = img_dir

        # JSON text is UTF-8; do not depend on the platform's default encoding.
        with open(filename, encoding="utf-8") as f:
            records = json.load(f)

        # One (prompt, image_path, plan) tuple per annotation record.
        self.data = [
            (item['goal_eng'], os.path.join(img_dir, item['image']), item['plan'])
            for item in records
        ]

    def __len__(self):
        """Return the number of annotation records."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(prompt, image, plan)`` for the record at ``index``.

        Raises whatever ``Image.open`` raises if the image file is missing
        or unreadable.
        """
        prompt, image_path, plan = self.data[index]

        # Image.open is lazy and keeps the file open; force the pixel data into
        # memory and release the handle so samples do not leak file descriptors.
        with Image.open(image_path) as image:
            image.load()

        if self.transform is not None:
            image = self.transform(image)

        return prompt, image, plan

In [None]:
#!g1.1
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel
from transformers import CLIPProcessor, CLIPModel
import torch.nn as nn
from PIL import Image


# Define the encoder class
class TextEncoder(nn.Module):
    """Frozen Hugging Face text encoder that returns per-token hidden states.

    Args:
        model_name: Checkpoint name or path passed to ``AutoTokenizer`` and
            ``AutoModel``.
        max_length: Value used as the tokenizer's ``model_max_length``; inputs
            are padded to it and truncated from the left. Defaults to 31.

    Attributes:
        embedding_size: ``hidden_size`` from the loaded model's config.
    """

    def __init__(self, model_name, max_length=31):
        super().__init__()

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            model_max_length=max_length,
            truncation_side='left',
        )
        self.model = AutoModel.from_pretrained(model_name)
        self.embedding_size = self.model.config.hidden_size

        # Freeze the backbone: none of its parameters receive gradients.
        self.model.requires_grad_(False)

    def forward(self, text_sequences):
        """Tokenize ``text_sequences`` and return the model's last hidden states.

        Returns:
            ``outputs.last_hidden_state`` of the wrapped model — presumably
            shaped (batch, max_length, embedding_size); confirm against the
            checkpoint in use.
        """
        inputs = self.tokenizer(
            text_sequences,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        # The tokenizer builds CPU tensors; move them to wherever the model
        # lives so the forward pass also works after `.to('cuda')`.
        inputs = inputs.to(self.model.device)
        outputs = self.model(**inputs)

        return outputs.last_hidden_state

class ImageEncoder(nn.Module):
    """Frozen CLIP model used to produce image embeddings.

    Args:
        model_name: CLIP checkpoint loaded for both the model and its
            processor. Defaults to ``"openai/clip-vit-large-patch14"``.

    Attributes:
        embedding_size: ``projection_dim`` from the loaded model's config.
    """

    def __init__(self, model_name="openai/clip-vit-large-patch14"):
        super().__init__()

        self.model = CLIPModel.from_pretrained(model_name)
        self.processor = CLIPProcessor.from_pretrained(model_name)
        self.embedding_size = self.model.config.projection_dim

        # Freeze the backbone: none of its parameters receive gradients.
        self.model.requires_grad_(False)

    def forward(self, img: "Image.Image"):
        """Return the CLIP ``image_embeds`` for ``img``.

        Args:
            img: Image(s) handed to the CLIP processor as ``images``.

        Returns:
            ``outputs.image_embeds`` — presumably (batch, embedding_size);
            confirm against the checkpoint in use.
        """
        # The full CLIP model is called with both modalities, so a fixed
        # placeholder caption is supplied; only the image embedding is read.
        inputs = self.processor(
            text=["What robot sees"],
            images=img,
            return_tensors="pt",
            padding=True
        )
        # The processor builds CPU tensors; move them to wherever the model
        # lives so the forward pass also works after `.to('cuda')`.
        inputs = inputs.to(self.model.device)
        outputs = self.model(**inputs)
        return outputs.image_embeds

class GeneralEncoder(nn.Module):
    """Projects text and image features to a common width and joins them.

    A ``TextEncoder("bert-base-uncased")`` and an ``ImageEncoder()`` supply the
    features; two linear layers map each encoder's ``embedding_size`` to
    ``encoder_out_dim``.

    Args:
        encoder_out_dim: Output width of both linear projections (default 128).
    """

    def __init__(self, encoder_out_dim=128):
        super(GeneralEncoder, self).__init__()
        self.encoder_out_dim = encoder_out_dim

        # Backbones (both freeze their own weights).
        self.text_encoder = TextEncoder("bert-base-uncased")
        self.image_encoder = ImageEncoder()

        # Trainable projections into the shared space.
        self.text_linear = nn.Linear(
            self.text_encoder.embedding_size, self.encoder_out_dim
        )
        self.image_linear = nn.Linear(
            self.image_encoder.embedding_size, self.encoder_out_dim
        )

    def forward(self, text, image):
        """Encode ``text`` and ``image`` and concatenate them along dim 1.

        The projected image features get a new axis at position 1 and are
        placed in front of the projected text features.
        """
        raw_text = self.text_encoder(text)
        raw_image = self.image_encoder(image)

        projected_text = self.text_linear(raw_text)
        projected_image = self.image_linear(raw_image)

        # Image entry first, then the text entries.
        image_token = projected_image.unsqueeze(1)
        return torch.cat([image_token, projected_text], dim=1)