In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as nnf
from torch.amp import autocast
from torch import einsum
import torch.nn.functional as F

import open_clip

from transformers import GPT2LMHeadModel, AutoTokenizer
from transformers import T5ForConditionalGeneration
from typing import Optional

from transformers.optimization import Adafactor
import numpy as np
from torch.utils.data import DataLoader
from tqdm import tqdm

import pickle
from torchmetrics.text import BLEUScore
from evaluate import load
from statistics import mean
import pandas as pd
from einops import rearrange
import math
import wandb
from accelerate import Accelerator
from accelerate.utils import set_seed


# Default compute device: use the GPU when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
def exists(val):
    """Return True when ``val`` is anything other than ``None``."""
    if val is None:
        return False
    return True


def default(val, d):
    """Return ``val`` unless it is ``None``; in that case return the fallback ``d``."""
    if exists(val):
        return val
    return d


def stable_softmax(t, dim=-1):
    """Softmax over ``dim`` computed after subtracting the per-slice maximum.

    Subtracting the maximum along ``dim`` leaves the result mathematically
    unchanged while keeping the exponentials in a safe numeric range.
    """
    peak = t.amax(dim=dim, keepdim=True)
    return (t - peak).softmax(dim=dim)


def scaled_dot_product(q, k, v, mask=None):
    """Scaled dot-product attention.

    Args:
        q, k, v: Tensors whose last two dimensions are (sequence, features).
        mask: Optional tensor broadcastable to the attention logits; positions
            where ``mask == 0`` are filled with a large negative value before
            the softmax.

    Returns:
        ``(values, attention)``: the attention-weighted sum of ``v`` and the
        softmax-normalised attention weights.
    """
    scale = math.sqrt(q.size()[-1])
    scores = torch.matmul(q, k.transpose(-2, -1)) / scale
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -9e15)
    weights = nnf.softmax(scores, dim=-1)
    return torch.matmul(weights, v), weights


def expand_mask(mask):
    """Expand an attention mask to four dimensions.

    A 2-D mask gains two leading singleton dimensions, a 3-D mask gains a
    singleton dimension at position 1, and a mask that already has four
    dimensions is returned unchanged.

    Args:
        mask: Tensor with at least two dimensions.

    Returns:
        The mask with at least four dimensions.

    Raises:
        AssertionError: If ``mask`` has fewer than two dimensions.
    """
    # The check used to be ``ndim > 2``, which rejected the 2-D masks that the
    # message and the ``while`` loop below are explicitly written to support.
    assert mask.ndim >= 2, "Mask must be at least 2-dimensional with seq_length x seq_length"
    if mask.ndim == 3:
        mask = mask.unsqueeze(1)
    while mask.ndim < 4:
        mask = mask.unsqueeze(0)
    return mask


class BidirectionalCrossAttention(nn.Module):
    """Cross-attention that updates two sequences from one shared similarity matrix.

    Each stream (``x`` and ``context``) has a single projection that serves as
    both query and key, plus a value projection. One similarity matrix is
    computed between the two streams and normalised along each axis, so ``x``
    attends to ``context`` and ``context`` attends to ``x`` in the same call.
    """

    def __init__(
            self,
            *,
            dim,
            heads=8,
            dim_head=64,
            context_dim=None,
            dropout=0.,
            talking_heads=False,
            prenorm=False,
    ):
        """
        Args:
            dim: Feature size of ``x``.
            heads: Number of attention heads.
            dim_head: Feature size per head; projections have ``heads * dim_head`` outputs.
            context_dim: Feature size of ``context``; falls back to ``dim`` when ``None``.
            dropout: Dropout probability applied to both attention maps.
            talking_heads: When true, a bias-free 1x1 ``Conv2d`` over the head
                dimension is applied to each attention map; otherwise identity.
            prenorm: When true, both inputs pass through ``LayerNorm`` first;
                otherwise identity.
        """
        super().__init__()
        context_dim = default(context_dim, dim)
        self.norm = nn.LayerNorm(dim) if prenorm else nn.Identity()
        self.context_norm = nn.LayerNorm(context_dim) if prenorm else nn.Identity()
        self.heads = heads
        # Standard 1/sqrt(d) scaling of the similarity logits.
        self.scale = dim_head ** -0.5
        inner_dim = dim_head * heads
        self.dropout = nn.Dropout(dropout)
        self.context_dropout = nn.Dropout(dropout)
        # One shared query/key projection per stream, plus a value projection per stream.
        self.to_qk = nn.Linear(dim, inner_dim, bias=False)
        self.context_to_qk = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(dim, inner_dim, bias=False)
        self.context_to_v = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_out = nn.Linear(inner_dim, dim)
        self.context_to_out = nn.Linear(inner_dim, context_dim)
        self.talking_heads = nn.Conv2d(heads, heads, 1, bias=False) if talking_heads else nn.Identity()
        self.context_talking_heads = nn.Conv2d(heads, heads, 1, bias=False) if talking_heads else nn.Identity()

    def forward(
            self,
            x,
            context,
            mask=None,
            context_mask=None,
            return_attn=False,
            rel_pos_bias=None
    ):
        """Attend in both directions between ``x`` and ``context``.

        Args:
            x: Tensor of rank 3 (the ``'b n (h d)'`` rearrange below requires
                batch, sequence and feature axes).
            context: Tensor of rank 3 with the same batch size as ``x``.
            mask: Optional ``(b, i)`` mask for ``x`` positions; the ``~``
                applied below needs a boolean mask.
            context_mask: Optional ``(b, j)`` mask for ``context`` positions.
            return_attn: When true, also return both attention maps.
            rel_pos_bias: Optional tensor added to the similarity logits.

        Returns:
            ``(out, context_out)``, or ``(out, context_out, attn, context_attn)``
            when ``return_attn`` is true. ``out`` has the feature size of ``x``
            and ``context_out`` the feature size of ``context``.
        """
        b, i, j, h, device = x.shape[0], x.shape[-2], context.shape[-2], self.heads, x.device
        x = self.norm(x)
        context = self.context_norm(context)
        qk, v = self.to_qk(x), self.to_v(x)
        context_qk, context_v = self.context_to_qk(context), self.context_to_v(context)
        # Split the projected features into heads: (b, n, h*d) -> (b, h, n, d).
        qk, context_qk, v, context_v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h),
                                           (qk, context_qk, v, context_v))
        # Single similarity matrix shared by both attention directions.
        sim = einsum('b h i d, b h j d -> b h i j', qk, context_qk) * self.scale
        if exists(rel_pos_bias):
            sim = sim + rel_pos_bias
        if exists(mask) or exists(context_mask):
            # A missing mask is treated as "every position is valid".
            mask = default(mask, torch.ones((b, i), device=device, dtype=torch.bool))
            context_mask = default(context_mask, torch.ones((b, j), device=device, dtype=torch.bool))
            # A pair (i, j) is kept only when both positions are unmasked.
            attn_mask = rearrange(mask, 'b i -> b 1 i 1') * rearrange(context_mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~attn_mask, -torch.finfo(sim.dtype).max)
        # Normalise over context positions (j) for x -> context attention ...
        attn = stable_softmax(sim, dim=-1)
        # ... and over x positions (i) for context -> x attention.
        context_attn = stable_softmax(sim, dim=-2)
        attn = self.dropout(attn)
        context_attn = self.context_dropout(context_attn)
        attn = self.talking_heads(attn)
        context_attn = self.context_talking_heads(context_attn)
        # Aggregate context values for every x position.
        out = einsum('b h i j, b h j d -> b h i d', attn, context_v)
        # Aggregate x values for every context position; in this einsum the labels
        # are swapped relative to ``sim`` (its ``j`` axis is the x axis).
        context_out = einsum('b h j i, b h j d -> b h i d', context_attn, v)
        # Merge heads back: (b, h, n, d) -> (b, n, h*d).
        out, context_out = map(lambda t: rearrange(t, 'b h n d -> b n (h d)'), (out, context_out))
        out = self.to_out(out)
        context_out = self.context_to_out(context_out)
        if return_attn:
            return out, context_out, attn, context_attn
        return out, context_out


class MultiheadAttention(nn.Module):
    """Multi-head self-attention with a fused query/key/value projection."""

    def __init__(self, input_dim, embed_dim, num_heads):
        """
        Args:
            input_dim: Feature size of the incoming sequence.
            embed_dim: Total attention width; must be divisible by ``num_heads``.
            num_heads: Number of attention heads.
        """
        super().__init__()
        assert embed_dim % num_heads == 0, "Embedding dimension must be 0 modulo number of heads."

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # A single linear layer produces q, k and v for every head at once.
        self.qkv_proj = nn.Linear(input_dim, 3 * embed_dim)
        self.o_proj = nn.Linear(embed_dim, embed_dim)

        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-initialise both weight matrices and zero both biases."""
        nn.init.xavier_uniform_(self.qkv_proj.weight)
        nn.init.zeros_(self.qkv_proj.bias)
        nn.init.xavier_uniform_(self.o_proj.weight)
        nn.init.zeros_(self.o_proj.bias)

    def forward(self, x, mask=None, return_attention=False):
        """Apply self-attention to ``x`` of shape (batch, sequence, input_dim).

        Args:
            x: Input sequence.
            mask: Optional attention mask, expanded with ``expand_mask``.
            return_attention: When true, also return the attention weights.

        Returns:
            The projected output, or ``(output, attention)`` when
            ``return_attention`` is true.
        """
        bsz, seq_len, _ = x.size()
        attn_mask = None if mask is None else expand_mask(mask)

        # (batch, seq, 3*embed) -> (batch, heads, seq, 3*head_dim), then split.
        fused = self.qkv_proj(x)
        fused = fused.reshape(bsz, seq_len, self.num_heads, 3 * self.head_dim).permute(0, 2, 1, 3)
        q, k, v = fused.chunk(3, dim=-1)

        per_head, attention = scaled_dot_product(q, k, v, mask=attn_mask)

        # (batch, heads, seq, head_dim) -> (batch, seq, embed).
        merged = per_head.permute(0, 2, 1, 3).reshape(bsz, seq_len, self.embed_dim)
        o = self.o_proj(merged)
        return (o, attention) if return_attention else o


class QFormerBlock(nn.Module):
    """One fusion block: text self-attention, image/text cross-attention, text MLP.

    ``bias`` is accepted but not used by this block.
    """

    def __init__(self, img_emb_size, text_emb_size, output_size, bias=True, act=nn.Tanh):
        super().__init__()

        self.attn = MultiheadAttention(text_emb_size, text_emb_size, 16)
        self.cross_attn = BidirectionalCrossAttention(
            dim=img_emb_size,
            heads=16,
            dim_head=1024,
            context_dim=text_emb_size
        )
        hidden = text_emb_size * 2
        mlp_layers = [
            nn.Linear(text_emb_size, hidden),
            act(),
            nn.Linear(hidden, hidden),
            act(),
            nn.Linear(hidden, output_size),
        ]
        self.text_mlp = nn.Sequential(*mlp_layers)

    @autocast("cuda")
    def forward(self, img_emb: torch.Tensor, text_emb: torch.Tensor) -> torch.Tensor:
        """Fuse one image embedding per sample with its text sequence.

        ``img_emb`` is reshaped to a one-token sequence before cross-attention.
        Despite the annotation, the method returns the pair
        ``(img_emb, text_emb)`` produced by cross-attention, with the text
        stream additionally passed through ``text_mlp``.
        """
        attended_text = self.attn(text_emb)
        img_tokens = img_emb.reshape(-1, 1, img_emb.shape[1])
        img_out, text_out = self.cross_attn(img_tokens, attended_text)
        return img_out, self.text_mlp(text_out)


class MySequential(nn.Sequential):
    """``nn.Sequential`` variant whose layers take and return several tensors."""

    def forward(self, *inp):
        """Feed ``inp`` through every layer, unpacking it as positional arguments.

        After each layer, if the first output has size 1 along dimension 1,
        that dimension is indexed away and the result is reduced to the
        two-element tuple ``(first[:, 0, :], second)``.
        """
        for layer in self._modules.values():
            inp = layer(*inp)
            head = inp[0]
            if head.shape[1] == 1:
                inp = (head[:, 0, :], inp[1])
        return inp


class QFormer(nn.Module):
    """Stack of ``QFormerBlock`` layers followed by a linear fusion head.

    ``bias`` and ``act`` are accepted but not forwarded to the blocks, which
    are built with their own defaults.
    """

    def __init__(self, img_emb_size, text_emb_size, output_size, n_blocks=4, bias=True, act=nn.Tanh):
        super().__init__()

        # Every block keeps the text width unchanged (output size == text_emb_size).
        stack = []
        for _ in range(n_blocks):
            stack.append(QFormerBlock(img_emb_size, text_emb_size, text_emb_size))
        self.blocks = MySequential(*stack)
        self.res = nn.Sequential(
            nn.Linear(img_emb_size + text_emb_size, output_size)
        )

    @autocast("cuda")
    def forward(self, img_emb: torch.Tensor, text_emb: torch.Tensor) -> torch.Tensor:
        """Return one fused vector per sample.

        The text sequence is mean-pooled over dimension 1, concatenated with
        the image embedding along dimension 1 and projected by ``res``.
        """
        img_emb, text_emb = self.blocks(img_emb, text_emb)
        pooled_text = text_emb.mean(dim=1)
        fused = torch.cat((img_emb, pooled_text), dim=1)
        return self.res(fused)


class MLP(nn.Module):
    """Two-layer perceptron with a hidden width of twice the input size."""

    def __init__(self, input_shape, output_shape, act=nn.Tanh):
        """
        Args:
            input_shape: Size of the input features.
            output_shape: Size of the output features.
            act: Activation class instantiated between the two linear layers.
        """
        super().__init__()
        hidden = input_shape * 2
        layers = [nn.Linear(input_shape, hidden), act(), nn.Linear(hidden, output_shape)]
        self.seq = nn.Sequential(*layers)

    @autocast("cuda")
    def forward(self, x):
        """Apply the perceptron to ``x``."""
        return self.seq(x)


def freeze(
        model,
        freeze_emb=False,
        freeze_ln=False,
        freeze_attn=False,
        freeze_ff=False,
        freeze_other=False,
):
    """Set ``requires_grad`` on every parameter according to its name.

    Parameter names are lower-cased and matched against substrings in this
    order (first match wins): ``'ln'``/``'norm'`` -> ``freeze_ln``,
    ``'embeddings'`` -> ``freeze_emb``, ``'mlp'`` -> ``freeze_ff``,
    ``'attn'`` -> ``freeze_attn``; everything else follows ``freeze_other``.

    Returns:
        The same ``model`` object, modified in place.
    """
    rules = (
        (('ln', 'norm'), freeze_ln),
        (('embeddings',), freeze_emb),
        (('mlp',), freeze_ff),
        (('attn',), freeze_attn),
    )
    for name, p in model.named_parameters():
        lowered = name.lower()
        frozen = freeze_other
        for keywords, flag in rules:
            if any(keyword in lowered for keyword in keywords):
                frozen = flag
                break
        p.requires_grad = not frozen

    return model


class ClipCaptionModel(nn.Module):
    """Visual question answering model: CLIP image encoder + QFormer + T5.

    The question is encoded with the T5 encoder, fused with the CLIP image
    embedding by ``QFormer`` into ``prefix_length`` embedding vectors, refined
    by an ``MLP`` and fed to T5 as ``inputs_embeds``. The attribute is called
    ``gpt`` but holds a ``T5ForConditionalGeneration``.
    """

    # NOTE: the ``nn.MSELoss()`` default is created once, when the class body is
    # executed, and is therefore shared by every instance that uses the default.
    def __init__(self, config, prefix_length: int, prefix_size: int = 640, dist_loss=nn.MSELoss()):
        """
        Args:
            config: Object providing ``encoder`` (open_clip model name),
                ``decoder`` (Hugging Face checkpoint name) and ``device``.
            prefix_length: Number of embedding vectors produced for T5.
            prefix_size: Image embedding width passed to ``QFormer``.
            dist_loss: Loss module stored as ``self.dist_loss``.
        """
        super(ClipCaptionModel, self).__init__()
        self.prefix_length = prefix_length
        self.clip_model, _, _ = open_clip.create_model_and_transforms(config.encoder, pretrained="laion400m_e32")
        self.tokenizer = AutoTokenizer.from_pretrained(config.decoder)
        # The pad token id is passed as the end-of-sequence id for this checkpoint.
        self.gpt = T5ForConditionalGeneration.from_pretrained(config.decoder,
                                                   eos_token_id=self.tokenizer.pad_token_id)
        self.gpt_embedding_size = self.gpt.get_input_embeddings().weight.shape[1]
        # QFormer emits one flat vector per sample that is later viewed as
        # (prefix_length, gpt_embedding_size).
        self.clip_project = QFormer(prefix_size, self.gpt_embedding_size,
                                    self.gpt_embedding_size * prefix_length)
        self.device = config.device
        self.dist_loss = dist_loss
        self.mlp = MLP(self.gpt_embedding_size, self.gpt_embedding_size)

        # Both pretrained backbones start frozen; only the QFormer and MLP train.
        for p in self.gpt.parameters():
            p.requires_grad = False
        for p in self.clip_model.parameters():
            p.requires_grad = False

    def get_text_embeddings(self, tokens):
        """Return the T5 encoder's last hidden state for ``tokens``, without gradients."""
        with torch.no_grad():
            embedding_text = self.gpt.encoder.forward(input_ids=tokens, return_dict=True)
            embedding_text = embedding_text.last_hidden_state
        return embedding_text
    def get_dummy_token(self, batch_size: int, device: torch.device) -> torch.Tensor:
        """Return a ``(batch_size, prefix_length)`` int64 tensor of zeros on ``device``."""
        return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)

    @autocast("cuda")
    def forward(self, query_tokens: torch.Tensor, query_mask: Optional[torch.Tensor],
                answer_tokens: torch.Tensor, answer_mask: Optional[torch.Tensor], image):
        """Run one supervised step.

        ``query_mask`` and ``answer_mask`` are accepted but not used here.

        Returns:
            ``(out, prefix_projections)``: the T5 output computed with
            ``labels=answer_tokens`` and the embeddings that were fed to T5.
        """
        # The question encoding never receives gradients, even if T5 is unfrozen.
        with torch.no_grad():
            embedding_text = self.gpt.encoder.forward(input_ids=query_tokens, return_dict=True)
            embedding_text = embedding_text.last_hidden_state
        image = self.clip_model.encode_image(image)
        # Flat QFormer output -> (batch, prefix_length, gpt_embedding_size).
        prefix_projections = self.clip_project(image.float(), embedding_text).view(-1, self.prefix_length,
                                                                                   self.gpt_embedding_size)
        prefix_projections = self.mlp(prefix_projections)
        out = self.gpt(inputs_embeds=prefix_projections, labels=answer_tokens)
        return out, prefix_projections

    def generate(self, image, texts, max_seq_len):
        """Generate one decoded string per (image, question) pair.

        Args:
            image: Batch of preprocessed images for ``clip_model.encode_image``.
            texts: List of question strings, tokenised here with padding and
                truncation to ``max_seq_len``.
            max_seq_len: Token length used when encoding ``texts``.

        Returns:
            List of strings produced by ``decode_question`` (defined elsewhere
            in this notebook) for every generated sequence.
        """
        tokens = torch.tensor(self.tokenizer.batch_encode_plus(texts, padding='max_length', max_length=max_seq_len, truncation=True)['input_ids'], dtype=torch.int64).to(self.device)
        with torch.no_grad():
            embedding_text = self.gpt.encoder.forward(input_ids=tokens, return_dict=True)
            embedding_text = embedding_text.last_hidden_state
        # NOTE(review): the image encoder, QFormer and MLP below run outside
        # ``torch.no_grad``; gradients are only disabled if the caller does so.
        image = self.clip_model.encode_image(image)
        prefix_projections = self.clip_project(image.float(), embedding_text).view(-1, self.prefix_length,
                                                                                   self.gpt_embedding_size)
        prefix_projections = self.mlp(prefix_projections)
        out = self.gpt.generate(
            inputs_embeds=prefix_projections,
            max_new_tokens=self.prefix_length,
            no_repeat_ngram_size=3,
            repetition_penalty=2.,
        )
        res = [decode_question(x, self.tokenizer) for x in out]
        return res

In [None]:
from torch.utils.data import Dataset
import sys
from matplotlib import pyplot as plt
import json
from PIL import Image
class VQAv2_Dataset(Dataset):
    """Question/answer/image triples read from a translated VQAv2 dump.

    Questions and answers are tokenised eagerly in ``__init__``; images are
    opened lazily in ``__getitem__``.
    """

    def __init__(self, config, dataset_path, coef_size=0.1,
                 tokenizer_name="", prefix_length=20, normalize_prefix=False, imagespath_split=None):
        """
        Args:
            config: Object providing ``decoder`` (default tokenizer name) and
                ``encoder`` (open_clip model name used for preprocessing).
            dataset_path: File whose first line is a JSON array of records
                with ``question``, ``answer`` and ``image_id`` keys.
            coef_size: Fraction of the records to load.
            tokenizer_name: Tokenizer checkpoint; falls back to ``config.decoder``.
            prefix_length: Length every question/answer is padded/truncated to.
            normalize_prefix: Stored on the instance; not used by this class.
            imagespath_split: Directory prefix prepended to every image id.
                ``None`` is treated as an empty prefix.
        """
        if not tokenizer_name:
            tokenizer_name = config.decoder
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        # Only the preprocessing transform is kept; the CLIP model itself is discarded.
        _, _, self.preprocess = open_clip.create_model_and_transforms(config.encoder, pretrained="laion400m_e32")
        self.prefix_length = prefix_length
        self.normalize_prefix = normalize_prefix

        with open(dataset_path, 'r') as f:
            dataset = json.loads(list(f)[0])

        self.img_paths = []
        self.query_tokens = []
        self.answer_tokens = []

        # Integer sample budget: the previous float budget was checked *after*
        # appending, so one record more than requested was always loaded.
        max_img = self._subset_size(len(dataset), coef_size)
        for i, el in tqdm(enumerate(dataset), total=max_img):
            if i >= max_img:
                break
            self.query_tokens.append(self._encode(el['question']))
            self.answer_tokens.append(self._encode(el['answer']))
            self.img_paths.append(self._image_path(imagespath_split, el['image_id']))
        del dataset
        sys.stdout.flush()
        self.max_seq_len = prefix_length

    @staticmethod
    def _subset_size(total, coef_size):
        """Number of records to keep: ``ceil(total * coef_size)`` clamped to ``[0, total]``."""
        return max(0, min(total, math.ceil(total * coef_size)))

    @staticmethod
    def _image_path(imagespath_split, image_id):
        """Build the ``.jpg`` path for ``image_id`` under ``imagespath_split``.

        When the prefix contains ``"val"``, ``"train"`` inside the image id is
        rewritten to ``"val"``. A ``None`` prefix is treated as ``""`` (it used
        to raise ``TypeError`` on the ``in`` check).
        """
        prefix = imagespath_split or ""
        if "val" in prefix:
            image_id = image_id.replace("train", "val")
        return prefix + image_id + ".jpg"

    def _encode(self, text):
        """Tokenise ``text`` to a ``(1, prefix_length)`` tensor, padded and truncated."""
        return self.tokenizer.encode(text, return_tensors="pt", padding='max_length',
                                     max_length=self.prefix_length, truncation=True)

    def pad_tokens(self, item: int):
        """Return ``(query_tokens, query_mask, answer_tokens, answer_mask)`` for ``item``.

        The "mask" entries are the token tensors themselves, not 0/1 masks.
        """
        query_tokens = self.query_tokens[item]
        answer_tokens = self.answer_tokens[item]
        query_mask = query_tokens
        answer_mask = answer_tokens
        return query_tokens[0], query_mask[0], answer_tokens[0], answer_mask[0]

    def get_image(self, item):
        """Open image ``item`` and return it resized to 256x256."""
        name = str(self.img_paths[item])
        # ``resize`` returns a new image, so the source file can be closed right away.
        with Image.open(name) as source:
            return source.resize((256, 256))

    def __len__(self) -> int:
        return len(self.img_paths)

    def __getitem__(self, item):
        """Return ``(query_tokens, query_mask, answer_tokens, answer_mask, image, item)``."""
        image = self.get_image(item)
        image = self.preprocess(image).unsqueeze(0)
        query_tokens, query_mask, answer_tokens, answer_mask = self.pad_tokens(item)
        return query_tokens, query_mask, answer_tokens, answer_mask, image[0], item

    def show_image(self, item):
        """Plot image ``item`` and print its decoded answer."""
        image = self.get_image(item)
        text = self.tokenizer.decode(self.pad_tokens(item)[2])
        plt.imshow(image)
        print(text)

In [None]:
class Config:
    """Run configuration shared by the dataset, model and training loop."""

    # open_clip model name used for the image encoder and its preprocessing.
    encoder: str = "ViT-B-16-plus-240"
    # Hugging Face checkpoint used for both the tokenizer and the T5 model.
    decoder: str = "ai-forever/FRED-T5-large"
    batch_size: int = 128
    num_epochs: int = 40
    # Epoch index at which fit_model re-enables gradients for the T5 model.
    frozen_gpt: int = 8
    # Epoch index at which fit_model re-enables gradients for the CLIP model.
    frozen_clip: int = 24
    learning_rate: float  = 2e-4
    # Checkpoints are written to f"{save_path}checkpoint_{epoch}.pkl", so the
    # trailing slash is part of the path.
    save_path: str = "saved_models_FRED-T5-large/fulliliya/"
    # Number of embedding vectors fed to T5; also the token padding length.
    prefix_length: int = 20
    # NOTE(review): only_prefix, prefix, save_every and warmup_steps are not
    # read anywhere in the visible notebook code.
    only_prefix: bool = False
    prefix: str = "prefix_small"
    device: str = "cuda:1"
    save_every: int = 1
    warmup_steps: int = 2000

In [None]:
# Metrics from the `evaluate` hub, loaded in a fixed order.
bertscore, meteor, rouge = (load(metric_name) for metric_name in ("bertscore", "meteor", "rouge"))

# Despite the name, the list holds every scorer: BLEU-1/2/3 at indices 0-2,
# then BERTScore, METEOR and ROUGE at indices 3-5.
bleu_scorers = [BLEUScore(n_gram=order) for order in (1, 2, 3)]
bleu_scorers.extend([bertscore, meteor, rouge])

In [None]:
import os

# The API key is read from the environment so that it never has to be pasted
# into (and committed with) the notebook. When WANDB_API_KEY is unset this
# passes None and wandb presumably falls back to its own credential lookup or
# prompt — confirm against the installed wandb version.
wandb.login(key=os.environ.get("WANDB_API_KEY"))
wandb.init(project="baseline-vqa-rugpt", sync_tensorboard=True, name="")

In [None]:
def decode_question(question_token, tokenizer):
    """Decode token ids with ``tokenizer`` and remove every ``<pad>`` marker.

    All occurrences of the literal ``<pad>`` are removed; text that follows a
    pad marker is kept.
    """
    return tokenizer.decode(question_token).replace("<pad>", "")

In [None]:
def train(model, optimizer, scheduler, loss_func, loader, epoch, args):
    """Train ``model`` for one epoch and pickle it afterwards.

    Args:
        model: Model exposing ``dist_loss``, ``get_text_embeddings``,
            ``generate`` and ``tokenizer``; its forward returns
            ``(outputs, proj)`` where ``outputs.logits`` holds the logits.
        optimizer: Optimizer stepped once per batch.
        scheduler: LR scheduler stepped once per batch.
        loss_func: Unused; kept for call compatibility.
        loader: DataLoader yielding ``(query_tokens, query_mask,
            answer_tokens, answer_mask, image, idx)``. Its ``dataset`` must
            provide ``tokenizer``, ``max_seq_len`` and integer indexing.
        epoch: Epoch number, used in the checkpoint file name.
        args: Config providing ``device`` and ``save_path``.
    """
    import os  # local import: `os` is not imported at notebook level

    model.train()
    # Use the dataset behind the loader instead of a notebook-level global, so
    # the function works with whichever loader it is given.
    dataset = loader.dataset
    pbar = tqdm(loader, total=len(loader))
    step = 0
    for (query_tokens, query_mask, answer_tokens, answer_mask, prefix, idx) in pbar:

        query_tokens, query_mask = query_tokens.to(args.device), query_mask.to(args.device)
        prefix = prefix.to(args.device, dtype=torch.bfloat16)
        answer_tokens, answer_mask = answer_tokens.to(args.device), answer_mask.to(args.device)
        outputs, proj = model(query_tokens, query_mask, answer_tokens, answer_mask, prefix)
        logits = outputs.logits
        # Token id 0 is excluded from the loss (presumably the pad id — confirm
        # against the tokenizer).
        loss = nnf.cross_entropy(logits.reshape(-1, logits.shape[-1]), answer_tokens.flatten().to(torch.int64),
                                 ignore_index=0)

        # Auxiliary loss pulling the projected prefix towards the encoder
        # embedding of the answer.
        loss2 = model.dist_loss(model.get_text_embeddings(answer_tokens).to(torch.float32), proj.to(torch.float32))
        loss += loss2
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 2)

        # Backpropagation step.
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        pbar.set_postfix({"loss": loss.item(), "dist_loss": loss2.item()})
        wandb.log({"loss": loss.item(), "dist_loss": loss2.item()})
        step += 1
        if step % 1000 == 0:
            print("QUESTION:", decode_question(query_tokens[0], dataset.tokenizer))
            print("ANSWER:", decode_question(answer_tokens[0], dataset.tokenizer))
            # Preview generation: switch to eval mode and disable autograd so
            # dropout does not perturb the sample and no graph is built, then
            # restore training mode.
            model.eval()
            with torch.no_grad():
                sample_image = dataset[int(idx[0])][4].unsqueeze(0).to(args.device)
                sample_question = decode_question(query_tokens[0], model.tokenizer)
                prediction = model.generate(sample_image, [sample_question], dataset.max_seq_len)[0]
            model.train()
            print("PREDICTED: ", prediction)

    checkpoint_path = f'{args.save_path}checkpoint_{epoch}.pkl'
    # Create the target directory first; otherwise a missing directory makes
    # open() fail only after the whole epoch has been trained.
    os.makedirs(os.path.dirname(checkpoint_path) or ".", exist_ok=True)
    # NOTE: this pickles the entire module object; loading such a file executes
    # arbitrary code, so only load checkpoints from trusted sources.
    with open(checkpoint_path, 'wb') as f:
        pickle.dump(model, f)

In [None]:
@torch.no_grad()
def evaluate(model, optimizer, scheduler, loss_func, loader, args):
    """Evaluate ``model`` on ``loader`` and log losses and text metrics to wandb.

    Args:
        model: Model exposing ``dist_loss``, ``get_text_embeddings``,
            ``generate`` and ``tokenizer``.
        optimizer, scheduler, loss_func: Unused; kept for call compatibility.
        loader: DataLoader yielding ``(query_tokens, query_mask,
            answer_tokens, answer_mask, image, idx)``. Its ``dataset`` must
            provide ``tokenizer``, ``max_seq_len``, ``get_image`` and integer
            indexing.
        args: Config providing ``device``.
    """
    model.eval()
    # Use the dataset behind the loader instead of a notebook-level global.
    dataset = loader.dataset
    pbar = tqdm(loader, total=len(loader))
    step = 0

    bl1, bl2, bl3 = [], [], []
    brt, mtr, rg = [], [], []
    val_losses, val_dist = [], []
    for (query_tokens, query_mask, answer_tokens, answer_mask, prefix, idx) in pbar:
        query_tokens, query_mask = query_tokens.to(args.device), query_mask.to(args.device)
        prefix = prefix.to(args.device, dtype=torch.bfloat16)
        answer_tokens, answer_mask = answer_tokens.to(args.device), answer_mask.to(args.device)
        outputs, proj = model(query_tokens, query_mask, answer_tokens, answer_mask, prefix)
        logits = outputs.logits
        loss = nnf.cross_entropy(logits.reshape(-1, logits.shape[-1]), answer_tokens.flatten().to(torch.int64),
                                 ignore_index=0)
        loss2 = model.dist_loss(model.get_text_embeddings(answer_tokens), proj)

        real = [decode_question(answer_tokens[i], model.tokenizer) for i in range(len(answer_tokens))]
        questions = [decode_question(query_tokens[j], model.tokenizer) for j in range(len(idx))]
        # Reload the full-precision images and stack them directly; this avoids
        # the tensor -> nested list -> tensor round trip.
        images = torch.stack([dataset[int(j)][4] for j in idx]).to(args.device)
        pred = model.generate(images, questions, dataset.max_seq_len)

        bl1.append(bleu_scorers[0](pred, real))
        bl2.append(bleu_scorers[1](pred, real))
        bl3.append(bleu_scorers[2](pred, real))
        # BERTScore returns one F1 value per sample. Extending (rather than
        # appending one list per batch) keeps ``brt`` flat, so the final mean
        # works even when the last batch is smaller than the others.
        brt.extend(bleu_scorers[3].compute(predictions=pred, references=real, lang="ru")['f1'])
        mtr.append(bleu_scorers[4].compute(predictions=pred, references=real)['meteor'])
        rg.append(bleu_scorers[5].compute(predictions=pred, references=real)['rougeL'])

        if step % 400 == 0:
            print("QUESTION:", decode_question(query_tokens[0], dataset.tokenizer))
            print("TEXT:", real[0])
            print("PREDICTED: ", pred[0])

            imgs = [
                wandb.Image(
                    dataset.get_image(int(idx[j])),
                    caption=f"REAL : {real[j]}, PREDICTED : {pred[j]}"
                )
                for j in range(len(idx))
            ]
            wandb.log({"Generations.": imgs})

        step += 1

        pbar.set_postfix({"val_loss": loss.item(), "val_dist": loss2.item()})
        val_losses.append(loss.item())
        val_dist.append(loss2.item())

    # One log call so that all epoch-level numbers share the same wandb step.
    wandb.log({
        "val_loss": mean(val_losses),
        "val_dist": mean(val_dist),
        "bleu_1": mean([score.item() for score in bl1]),
        "bleu_2": mean([score.item() for score in bl2]),
        "bleu_3": mean([score.item() for score in bl3]),
        "bert_score": np.mean(brt),
        "meteor_score": np.mean(mtr),
        "rouge_score": np.mean(rg)
    })


In [None]:
def fit_model(args, model, train_loader, val_loader):
    """Train and evaluate ``model`` for ``args.num_epochs`` epochs.

    The T5 and CLIP sub-models have their parameters re-enabled for training
    at epochs ``args.frozen_gpt`` and ``args.frozen_clip`` respectively.

    Args:
        args: Config providing ``learning_rate``, ``num_epochs``,
            ``batch_size``, ``device``, ``frozen_gpt`` and ``frozen_clip``.
        model: Model with ``gpt`` and ``clip_model`` sub-modules.
        train_loader: Loader passed to ``train``.
        val_loader: Loader passed to ``evaluate``.
    """
    # Update the active run's config. Assigning a dict to ``wandb.config`` only
    # rebinds the module attribute and records nothing on the run.
    wandb.config.update({
        "learning_rate": args.learning_rate,
        "epochs": args.num_epochs,
        "batch_size": args.batch_size
    }, allow_val_change=True)

    model = model.to(args.device)

    wandb.watch(model, log_freq=10, log="gradients")

    model.train()

    # Unused by train/evaluate (they compute the loss themselves); still passed
    # through because their signatures expect it.
    loss_func = nn.CrossEntropyLoss()
    # All parameters are registered, including currently frozen ones, so that
    # they are optimised as soon as they are unfrozen below.
    optimizer = Adafactor(model.parameters(), lr=args.learning_rate,
                          relative_step=False  # required when an explicit lr is given
                          )

    # The scheduler is stepped once per batch, so T_max counts batches.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=15000
    )
    print("Start train model")
    for epoch in range(args.num_epochs):
        if epoch == args.frozen_gpt:
            print("GPT UNFROZEN")
            for p in model.gpt.parameters():
                p.requires_grad = True
        if epoch == args.frozen_clip:
            print("CLIP UNFROZEN")
            for p in model.clip_model.parameters():
                p.requires_grad = True
        print(f"---------- Train epoch {epoch} ---------")
        train(model, optimizer, scheduler, loss_func, train_loader, epoch, args)
        print(f"---------- Evaluate epoch {epoch} ---------")
        evaluate(model, optimizer, scheduler, loss_func, val_loader, args)

In [None]:
# Shared run configuration.
config = Config()

# Training split: half of the translated VQAv2 training records.
train_dataset = VQAv2_Dataset(
    config,
    dataset_path="VQAv2_train_translation.jsonl",
    imagespath_split="trainvqa/train2014/",
    coef_size=0.5,
)

# Validation split: 5% of the translated VQAv2 validation records.
val_dataset = VQAv2_Dataset(
    config,
    dataset_path="VQAv2_val_translation.jsonl",
    imagespath_split="valvqa/val2014/",
    coef_size=0.05,
)

In [None]:
# Model under training; the prefix length comes from the shared config.
model = ClipCaptionModel(config, config.prefix_length)

# Both loaders shuffle and keep the final (possibly smaller) batch.
train_loader = DataLoader(
    train_dataset,
    batch_size=config.batch_size,
    num_workers=20,
    shuffle=True,
    drop_last=False,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=config.batch_size,
    num_workers=20,
    shuffle=True,
    drop_last=False,
)

In [None]:
# Run the full train/evaluate schedule with the objects built in the cells above.
fit_model(config, model, train_loader, val_loader)

In [None]:
# Debug buffers: the ClipCaptionModel redefined in the next cell appends the
# text and image embeddings of every forward pass to these lists.
embtext, imgemb = [], []

In [None]:
# NOTE(review): this redefinition shadows the ClipCaptionModel defined earlier
# in the notebook. It differs in two ways: T5 is loaded without the
# ``eos_token_id`` override, and ``forward`` appends its intermediate
# embeddings to the module-level lists ``embtext`` and ``imgemb``.
class ClipCaptionModel(nn.Module):
    """Debug variant of the captioning model that records intermediate embeddings."""

    # NOTE: the ``nn.MSELoss()`` default is created once, when the class body is
    # executed, and is therefore shared by every instance that uses the default.
    def __init__(self, config, prefix_length: int, prefix_size: int = 640, dist_loss=nn.MSELoss()):
        """
        Args:
            config: Object providing ``encoder`` (open_clip model name),
                ``decoder`` (Hugging Face checkpoint name) and ``device``.
            prefix_length: Number of embedding vectors produced for T5.
            prefix_size: Image embedding width passed to ``QFormer``.
            dist_loss: Loss module stored as ``self.dist_loss``.
        """
        super(ClipCaptionModel, self).__init__()
        self.prefix_length = prefix_length
        self.clip_model, _, _ = open_clip.create_model_and_transforms(config.encoder, pretrained="laion400m_e32")
        self.tokenizer = AutoTokenizer.from_pretrained(config.decoder)
        self.gpt = T5ForConditionalGeneration.from_pretrained(config.decoder)
        # (The earlier definition additionally passed
        # eos_token_id=self.tokenizer.pad_token_id here.)
        self.gpt_embedding_size = self.gpt.get_input_embeddings().weight.shape[1]
        self.clip_project = QFormer(prefix_size, self.gpt_embedding_size,
                                    self.gpt_embedding_size * prefix_length)
        self.device = config.device
        self.dist_loss = dist_loss
        self.mlp = MLP(self.gpt_embedding_size, self.gpt_embedding_size)

        # Both pretrained backbones start frozen.
        for p in self.gpt.parameters():
            p.requires_grad = False
        for p in self.clip_model.parameters():
            p.requires_grad = False

    def get_text_embeddings(self, tokens):
        """Return the T5 encoder's last hidden state for ``tokens``, without gradients."""
        with torch.no_grad():
            embedding_text = self.gpt.encoder.forward(input_ids=tokens, return_dict=True)
            embedding_text = embedding_text.last_hidden_state
        return embedding_text
    def get_dummy_token(self, batch_size: int, device: torch.device) -> torch.Tensor:
        """Return a ``(batch_size, prefix_length)`` int64 tensor of zeros on ``device``."""
        return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)

    @autocast("cuda")
    def forward(self, query_tokens: torch.Tensor, query_mask: Optional[torch.Tensor],
                answer_tokens: torch.Tensor, answer_mask: Optional[torch.Tensor], image):
        """Run one supervised step and record the intermediate embeddings.

        ``query_mask`` and ``answer_mask`` are accepted but not used here.

        Returns:
            ``(out, prefix_projections)``: the T5 output computed with
            ``labels=answer_tokens`` and the embeddings that were fed to T5.
        """
        with torch.no_grad():
            embedding_text = self.gpt.encoder.forward(input_ids=query_tokens, return_dict=True)
            embedding_text = embedding_text.last_hidden_state
        image = self.clip_model.encode_image(image)
        # Debug side effect: every call grows these module-level lists and keeps
        # the tensors alive, so memory use increases with each forward pass.
        embtext.append(embedding_text)
        imgemb.append(image)
        # Flat QFormer output -> (batch, prefix_length, gpt_embedding_size).
        prefix_projections = self.clip_project(image.float(), embedding_text).view(-1, self.prefix_length,
                                                                                   self.gpt_embedding_size)
        prefix_projections = self.mlp(prefix_projections)
        out = self.gpt(inputs_embeds=prefix_projections, labels=answer_tokens)
        return out, prefix_projections

    def generate(self, image, texts, max_seq_len):
        """Generate one decoded string per (image, question) pair.

        Args:
            image: Batch of preprocessed images for ``clip_model.encode_image``.
            texts: List of question strings, tokenised here with padding and
                truncation to ``max_seq_len``.
            max_seq_len: Token length used when encoding ``texts``.

        Returns:
            List of strings produced by ``decode_question`` for every
            generated sequence.
        """
        tokens = torch.tensor(self.tokenizer.batch_encode_plus(texts, padding='max_length', max_length=max_seq_len, truncation=True)['input_ids'], dtype=torch.int64).to(self.device)
        with torch.no_grad():
            embedding_text = self.gpt.encoder.forward(input_ids=tokens, return_dict=True)
            embedding_text = embedding_text.last_hidden_state
        image = self.clip_model.encode_image(image)
        prefix_projections = self.clip_project(image.float(), embedding_text).view(-1, self.prefix_length,
                                                                                   self.gpt_embedding_size)
        prefix_projections = self.mlp(prefix_projections)
        out = self.gpt.generate(
            inputs_embeds=prefix_projections,
            max_new_tokens=self.prefix_length,
            no_repeat_ngram_size=3,
            repetition_penalty=2.,
        )
        res = [decode_question(x, self.tokenizer) for x in out]
        return res

In [None]:
# Scratch cell: builds the model's components as standalone globals for
# interactive inspection. This rebinds the notebook-level `device` to
# `config.device`.
prefix_length = 20
clip_model, _, _ = open_clip.create_model_and_transforms(config.encoder, pretrained="laion400m_e32")
tokenizer = AutoTokenizer.from_pretrained(config.decoder)
gpt = T5ForConditionalGeneration.from_pretrained(config.decoder,
                                            eos_token_id=tokenizer.pad_token_id)
gpt_embedding_size = gpt.get_input_embeddings().weight.shape[1]
# NOTE(review): the image width here is 512, while ClipCaptionModel defaults to
# prefix_size=640 — confirm which matches the chosen CLIP encoder before
# feeding real image embeddings into this QFormer.
clip_project = QFormer(512, gpt_embedding_size,
                            gpt_embedding_size * prefix_length)
device = config.device
dist_loss=nn.MSELoss()
mlp = MLP(gpt_embedding_size, gpt_embedding_size)

In [None]:
# Fresh model (prefix length 20) built from the most recent ClipCaptionModel
# definition and moved to the configured device in one step.
model = ClipCaptionModel(config, 20).to(config.device)

In [None]:
# Pull a single batch from the training loader for interactive debugging.
pbar = tqdm(train_loader, total=len(train_loader))
# Placeholders so the names exist even if the loader yields nothing.
query_tokens = query_mask = answer_tokens = answer_mask = prefix = idx = 0
for batch in pbar:
    query_tokens, query_mask, answer_tokens, answer_mask, prefix, idx = batch
    query_tokens = query_tokens.to(config.device)
    query_mask = query_mask.to(config.device)
    # Images are cast to bfloat16 on the way to the device.
    prefix = prefix.to(config.device, dtype=torch.bfloat16)
    answer_tokens = answer_tokens.to(config.device)
    answer_mask = answer_mask.to(config.device)
    break

In [None]:
# Ensure the question tokens live on the configured device (a no-op when the
# batch-fetching cell has already moved them there).
query_tokens = query_tokens.to(config.device)

In [None]:
# Scratch cell: reload the tokenizer and a standalone T5 model, then encode the
# current batch of questions without tracking gradients.
tokenizer = AutoTokenizer.from_pretrained(config.decoder)
gpt = T5ForConditionalGeneration.from_pretrained(config.decoder,
                                                   eos_token_id=tokenizer.pad_token_id)
gpt = gpt.to(config.device)
with torch.no_grad():
    embedding_text = gpt.encoder.forward(input_ids=query_tokens, return_dict=True)
    # Keep only the encoder's final hidden states.
    embedding_text = embedding_text.last_hidden_state

In [None]:
# Inspect the question token ids of sample 51 in the current batch (the batch
# must contain at least 52 samples).
query_tokens[51]

In [None]:
# Display the encoder hidden states computed for the current batch.
embedding_text

In [None]:
# Single forward pass on the fetched batch; `proj` holds the embeddings that
# were fed to T5.
outputs, proj = model(query_tokens, query_mask, answer_tokens, answer_mask, prefix)

In [None]:
# Display the projected prefix embeddings from the forward pass above.
proj

In [None]:
# Plain language-model smoke test of the standalone T5 checkpoint.
lm_text='<LM>Принялся Кутузов рассказывать свою историю как он сюда попал. Началось'
# Place the prompt on the same device as the T5 model.
input_ids=torch.tensor([tokenizer.encode(lm_text)]).to(gpt.device)
# Use the bare T5 model here: `model` is a ClipCaptionModel whose
# generate(image, texts, max_seq_len) does not accept these keyword arguments
# and would raise a TypeError.
outputs=gpt.generate(input_ids,eos_token_id=tokenizer.eos_token_id,early_stopping=True)
# Skip the first generated token (the decoder start token) when decoding.
print(tokenizer.decode(outputs[0][1:]))

In [None]:
# Encode the current batch of questions with the standalone T5 encoder. Unlike
# the earlier scratch cell, this call is not wrapped in torch.no_grad().
gpt = gpt.to(config.device)
outputs = gpt.encoder.forward(input_ids=query_tokens, return_dict=True)
embeddings = outputs.last_hidden_state

In [None]:
# Rebuild the model from the most recent ClipCaptionModel definition.
model = ClipCaptionModel(config, config.prefix_length)

In [None]:
# Move the rebuilt model to the configured device.
model = model.to(config.device)

In [None]:
# Forward pass of the rebuilt model on the previously fetched batch.
outputs, proj = model(query_tokens, query_mask, answer_tokens, answer_mask, prefix)