# Download the GPT-2 pretrained weights

In [None]:
!curl --output gpt2-pytorch_model.bin https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  522M  100  522M    0     0  55.7M      0  0:00:09  0:00:09 --:--:-- 53.4M


# Implementation of VisualGPT 

## `models/transformer/utils.py`

In [None]:
import torch
from torch import nn
from torch.nn import functional as F


def position_embedding(input, d_model):
    """Return sinusoidal position encodings for the given positions.

    Even output columns hold ``sin(pos / 10000 ** (2 * i / d_model))`` and the
    odd columns hold the matching cosine.

    :param input: Tensor of positions; it is flattened, so any shape works.
    :param d_model: Encoding width. Odd widths are supported: the last (even)
        column then carries only a sine term.
    :return: Float tensor of shape ``(input.numel(), d_model)`` on
        ``input.device``.
    """
    positions = input.reshape(-1, 1)
    # One frequency per sine column; there are ceil(d_model / 2) of them.
    n_freqs = (d_model + 1) // 2
    dim = torch.arange(n_freqs, dtype=torch.float32, device=positions.device).view(1, -1)
    angles = positions / 10000 ** (2 * dim / d_model)

    out = torch.zeros((positions.shape[0], d_model), device=positions.device)
    out[:, ::2] = torch.sin(angles)
    # With an odd d_model there is one fewer cosine column than sine columns.
    out[:, 1::2] = torch.cos(angles[:, :d_model // 2])
    return out


def sinusoid_encoding_table(max_len, d_model, padding_idx=None):
    """Build a table of sinusoidal encodings for positions ``0 .. max_len - 1``.

    :param max_len: Number of positions (rows).
    :param d_model: Encoding width (columns).
    :param padding_idx: Optional row index that is zeroed out so that this
        position carries no positional signal.
    :return: Tensor with one encoding row per position.
    """
    table = position_embedding(torch.arange(max_len, dtype=torch.float32), d_model)
    if padding_idx is None:
        return table
    table[padding_idx] = 0
    return table


class PositionWiseFeedForward(nn.Module):
    '''
    Position-wise feed-forward sublayer with a residual connection.

    The sublayer is ``fc2(dropout_2(relu(fc1(x))))``. By default (post-norm)
    its output is dropped out, added to the input and layer-normalized. With
    ``identity_map_reordering`` the input is layer-normalized first (pre-norm)
    and ``dropout(relu(sublayer))`` is added to the un-normalized input.
    '''

    def __init__(self, d_model=512, d_ff=2048, dropout=.1, identity_map_reordering=False):
        super(PositionWiseFeedForward, self).__init__()
        self.identity_map_reordering = identity_map_reordering
        # Creation order fixed so random initialisation stays reproducible.
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(p=dropout)  # applied to the sublayer output
        self.dropout_2 = nn.Dropout(p=dropout)  # applied to the hidden activations
        self.layer_norm = nn.LayerNorm(d_model)

    def _ffn(self, x):
        """Two-layer MLP with dropout on the hidden activations."""
        hidden = F.relu(self.fc1(x))
        return self.fc2(self.dropout_2(hidden))

    def forward(self, input):
        if not self.identity_map_reordering:
            residual = self.dropout(self._ffn(input))
            return self.layer_norm(input + residual)
        update = self._ffn(self.layer_norm(input))
        return input + self.dropout(torch.relu(update))

## `utils/typing.py`

In [None]:
from typing import Union, Sequence, Tuple
import torch

# A single tensor, or any sequence (list/tuple) of tensors.
TensorOrSequence = Union[Sequence[torch.Tensor], torch.Tensor]
# An optional tensor; Union[X, type(None)] is the same type as Union[X, None].
TensorOrNone = Union[torch.Tensor, type(None)]

## `models/container.py`

In [None]:
from contextlib import contextmanager
from torch import nn


class Module(nn.Module):
    """``nn.Module`` with optional resettable "state" buffers.

    Subclasses declare state with :meth:`register_state`. While statefulness
    is enabled, each state buffer is replaced by a copy of its default with a
    leading batch dimension. Disabling restores the registered default.
    Enabling and disabling recurse into child modules that are also ``Module``.
    """

    def __init__(self):
        super(Module, self).__init__()
        self._is_stateful = False
        self._state_names = []
        self._state_defaults = dict()

    def register_state(self, name: str, default: TensorOrNone):
        """Register buffer ``name`` whose value is reset to ``default``.

        A detached clone of ``default`` is kept, so later updates to the live
        buffer never change the value that is restored.
        """
        self._state_names.append(name)
        self._state_defaults[name] = None if default is None else default.clone().detach()
        self.register_buffer(name, default)

    def states(self):
        """Yield this module's state buffers, then those of ``Module`` children, depth-first."""
        yield from (self._buffers[state] for state in self._state_names)
        for child in self.children():
            if isinstance(child, Module):
                yield from child.states()

    def apply_to_states(self, fn):
        """Replace every state buffer (recursively) with ``fn(buffer)``."""
        for state in self._state_names:
            self._buffers[state] = fn(self._buffers[state])
        for child in self.children():
            if isinstance(child, Module):
                child.apply_to_states(fn)

    def _init_states(self, batch_size: int):
        # Each state becomes default[None] expanded to batch_size (made contiguous),
        # on the device of the current buffer.
        for state in self._state_names:
            default = self._state_defaults[state]
            if default is None:
                self._buffers[state] = None
                continue
            device = self._buffers[state].device
            batched = default.clone().detach().to(device).unsqueeze(0)
            self._buffers[state] = batched.expand([batch_size] + list(batched.shape[1:])).contiguous()

    def _reset_states(self):
        # Restore the registered default (no batch dimension) on the buffer's device.
        for state in self._state_names:
            default = self._state_defaults[state]
            self._buffers[state] = (
                None if default is None
                else default.clone().detach().to(self._buffers[state].device)
            )

    def enable_statefulness(self, batch_size: int):
        """Turn on stateful mode for this module and all ``Module`` children."""
        for child in self.children():
            if isinstance(child, Module):
                child.enable_statefulness(batch_size)
        self._init_states(batch_size)
        self._is_stateful = True

    def disable_statefulness(self):
        """Turn off stateful mode and restore every state to its default."""
        for child in self.children():
            if isinstance(child, Module):
                child.disable_statefulness()
        self._reset_states()
        self._is_stateful = False

    @contextmanager
    def statefulness(self, batch_size: int):
        """Context manager: stateful inside the ``with`` block, reset on exit (even on error)."""
        self.enable_statefulness(batch_size)
        try:
            yield
        finally:
            self.disable_statefulness()


class ModuleList(nn.ModuleList, Module):
    """``nn.ModuleList`` that also exposes the statefulness API of ``Module``."""


class ModuleDict(nn.ModuleDict, Module):
    """``nn.ModuleDict`` that also exposes the statefulness API of ``Module``."""

## `models/transformer/attention.py`

In [None]:
import numpy as np
import torch
from torch import nn
from torch.nn import Parameter

class Conv1D(nn.Module):
    """Affine map ``y = x @ weight.T + bias`` over the last dimension of ``x``.

    :param nf: Number of output features.
    :param nx: Number of input features.

    ``weight`` is stored as ``(nf, nx)`` and transposed in :meth:`forward`, so
    ``nf`` and ``nx`` may differ. Before this fix it was allocated as
    ``(nx, nf)``, which only worked when ``nf == nx``. For square layers the
    shape, the xavier draw and the result are unchanged.
    """

    def __init__(self, nf, nx):
        super(Conv1D, self).__init__()
        self.nf = nf
        w = torch.empty(nf, nx)
        nn.init.xavier_uniform_(w)
        self.weight = Parameter(w)
        self.bias = Parameter(torch.zeros(nf))

    def forward(self, x):
        size_out = x.size()[:-1] + (self.nf,)
        flat = x.contiguous().view(-1, x.size(-1))  # (N, nx)
        x = torch.addmm(self.bias, flat, self.weight.transpose(1, 0))  # (N, nf)
        return x.view(*size_out)



class ScaledDotProductAttention(nn.Module):
    '''
    Multi-head scaled dot-product attention.

    Queries, keys and values are projected by ``fc_q``/``fc_k``/``fc_v`` into
    ``h`` heads, attention is computed per head, and the merged heads are
    projected back to ``d_model`` by ``fc_o``.
    '''

    def __init__(self, d_model, d_k, d_v, h):
        '''
        :param d_model: Output dimensionality of the model
        :param d_k: Dimensionality of queries and keys
        :param d_v: Dimensionality of values
        :param h: Number of heads
        '''
        super(ScaledDotProductAttention, self).__init__()
        self.fc_q = nn.Linear(d_model, h * d_k)
        self.fc_k = nn.Linear(d_model, h * d_k)
        self.fc_v = nn.Linear(d_model, h * d_v)
        self.fc_o = nn.Linear(h * d_v, d_model)

        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_v
        self.h = h

        # NOTE: forward() does not use these GPT-style projections or
        # split_heads. They stay registered so existing state_dicts still load.
        # They were hard-coded to 768 features / 12 heads. Sizing them from
        # d_model / h gives identical shapes (and RNG use) when d_model=768, h=12.
        self.c_attn_query = Conv1D(d_model, d_model)
        self.c_attn_key = Conv1D(d_model, d_model)
        self.c_attn_value = Conv1D(d_model, d_model)

        self.split_size = d_model
        self.n_head = h
        self.init_weights()

        self.flag = None

    def init_weights(self):
        """Xavier-uniform weights and zero biases for the q/k/v/o projections."""
        nn.init.xavier_uniform_(self.fc_q.weight)
        nn.init.xavier_uniform_(self.fc_k.weight)
        nn.init.xavier_uniform_(self.fc_v.weight)
        nn.init.xavier_uniform_(self.fc_o.weight)
        nn.init.constant_(self.fc_q.bias, 0)
        nn.init.constant_(self.fc_k.bias, 0)
        nn.init.constant_(self.fc_v.bias, 0)
        nn.init.constant_(self.fc_o.bias, 0)

    def split_heads(self, x, k=False):
        new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
        x = x.view(*new_x_shape)  # in Tensorflow implem: fct split_states
        if k:
            return x.permute(0, 2, 3, 1)  # (batch, head, head_features, seq_length)
        else:
            return x.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)

    def forward(self, queries, keys, values, attention_mask=None, attention_weights=None):
        '''
        Computes
        :param queries: Queries (b_s, nq, d_model)
        :param keys: Keys (b_s, nk, d_model)
        :param values: Values (b_s, nk, d_model)
        :param attention_mask: Mask over attention values (b_s, h, nq, nk). True indicates masking.
        :param attention_weights: Multiplicative weights for attention values (b_s, h, nq, nk).
        :return: Attended values (b_s, nq, d_model)
        '''
        b_s, nq = queries.shape[:2]
        nk = keys.shape[1]

        q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3)  # (b_s, h, nq, d_k)
        k = self.fc_k(keys).view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1)  # (b_s, h, d_k, nk)
        v = self.fc_v(values).view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3)  # (b_s, h, nk, d_v)

        att = torch.matmul(q, k) / np.sqrt(self.d_k)  # (b_s, h, nq, nk)
        if attention_weights is not None:
            att = att * attention_weights
        if attention_mask is not None:
            # -inf: a query whose keys are all masked yields NaN after softmax.
            att = att.masked_fill(attention_mask, -np.inf)
        att = torch.softmax(att, -1)
        out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v)  # (b_s, nq, h*d_v)
        out = self.fc_o(out)  # (b_s, nq, d_model)
        return out


class ScaledDotProductAttentionMemory(nn.Module):
    '''
    Scaled dot-product attention with (currently disabled) memory slots.

    The learnable memory keys/values are commented out in this version, so
    ``m`` is accepted but unused. The layer therefore computes plain multi-head
    attention over ``keys``/``values``.
    '''

    def __init__(self, d_model, d_k, d_v, h, m):
        '''
        :param d_model: Output dimensionality of the model
        :param d_k: Dimensionality of queries and keys
        :param d_v: Dimensionality of values
        :param h: Number of heads
        :param m: Number of memory slots (currently unused)
        '''
        super(ScaledDotProductAttentionMemory, self).__init__()
        self.fc_q = nn.Linear(d_model, h * d_k)
        self.fc_k = nn.Linear(d_model, h * d_k)
        self.fc_v = nn.Linear(d_model, h * d_v)
        self.fc_o = nn.Linear(h * d_v, d_model)

        self.d_model, self.d_k, self.d_v, self.h = d_model, d_k, d_v, h

        self.init_weights()

    def init_weights(self):
        """Xavier-uniform weights, then zero biases, for the q/k/v/o projections."""
        projections = (self.fc_q, self.fc_k, self.fc_v, self.fc_o)
        for proj in projections:
            nn.init.xavier_uniform_(proj.weight)
        for proj in projections:
            nn.init.constant_(proj.bias, 0)

    def forward(self, queries, keys, values, attention_mask=None, attention_weights=None):
        '''
        :param queries: Queries (b_s, nq, d_model)
        :param keys: Keys (b_s, nk, d_model)
        :param values: Values (b_s, nk, d_model)
        :param attention_mask: Mask over attention values (b_s, h, nq, nk). True indicates masking.
        :param attention_weights: Multiplicative weights for attention values (b_s, h, nq, nk).
        :return: Attended values (b_s, nq, d_model)
        '''
        batch, n_queries = queries.shape[:2]
        n_keys = keys.shape[1]

        q = self.fc_q(queries).view(batch, n_queries, self.h, self.d_k).permute(0, 2, 1, 3)  # (b_s, h, nq, d_k)
        k = self.fc_k(keys).view(batch, n_keys, self.h, self.d_k).permute(0, 2, 3, 1)  # (b_s, h, d_k, nk)
        v = self.fc_v(values).view(batch, n_keys, self.h, self.d_v).permute(0, 2, 1, 3)  # (b_s, h, nk, d_v)

        scores = torch.matmul(q, k) / np.sqrt(self.d_k)  # (b_s, h, nq, nk)
        # Only the first n_keys columns are real keys; anything after them
        # would be memory slots (none at present).
        if attention_weights is not None:
            scores = torch.cat([scores[..., :n_keys] * attention_weights, scores[..., n_keys:]], -1)
        if attention_mask is not None:
            scores[..., :n_keys] = scores[..., :n_keys].masked_fill(attention_mask, -np.inf)
        probs = torch.softmax(scores, -1)

        merged = torch.matmul(probs, v).permute(0, 2, 1, 3).contiguous()
        merged = merged.view(batch, n_queries, self.h * self.d_v)  # (b_s, nq, h*d_v)
        return self.fc_o(merged)  # (b_s, nq, d_model)


class MultiHeadAttention(Module):
    '''
    Multi-head attention layer with dropout, a residual connection and layer
    normalization.

    If built with ``can_be_stateful=True`` and statefulness is enabled, the
    incoming keys/values are appended along dim 1 to the running buffers, and
    attention is computed over everything accumulated so far.
    '''

    def __init__(self, d_model, d_k, d_v, h, dropout=.1, identity_map_reordering=False, can_be_stateful=False,
                 attention_module=None, attention_module_kwargs=None):
        super(MultiHeadAttention, self).__init__()
        self.identity_map_reordering = identity_map_reordering
        if attention_module is None:
            self.attention = ScaledDotProductAttention(d_model=d_model, d_k=d_k, d_v=d_v, h=h)
        else:
            extra_kwargs = attention_module_kwargs if attention_module_kwargs is not None else {}
            self.attention = attention_module(d_model=d_model, d_k=d_k, d_v=d_v, h=h, **extra_kwargs)
        self.dropout = nn.Dropout(p=dropout)
        self.layer_norm = nn.LayerNorm(d_model)

        self.can_be_stateful = can_be_stateful
        if self.can_be_stateful:
            # Empty along the sequence axis; enable_statefulness adds the batch axis.
            self.register_state('running_keys', torch.zeros((0, d_model)))
            self.register_state('running_values', torch.zeros((0, d_model)))

    def forward(self, queries, keys, values, attention_mask=None, attention_weights=None):
        if self.can_be_stateful and self._is_stateful:
            self.running_keys = torch.cat([self.running_keys, keys], 1)
            self.running_values = torch.cat([self.running_values, values], 1)
            keys, values = self.running_keys, self.running_values

        if not self.identity_map_reordering:
            # Post-norm: attention -> dropout -> residual add -> layer norm.
            attended = self.attention(queries, keys, values, attention_mask, attention_weights)
            return self.layer_norm(queries + self.dropout(attended))

        # Pre-norm: the same layer norm is applied to queries, keys and values.
        normed_q, normed_k, normed_v = (self.layer_norm(t) for t in (queries, keys, values))
        attended = self.attention(normed_q, normed_k, normed_v, attention_mask, attention_weights)
        return queries + self.dropout(torch.relu(attended))

## `models/encoders.py` (image encoder)

In [None]:
import math
import torch
from torch import nn
from torch.nn import functional as F


class EncoderLayer(nn.Module):
    """One encoder layer: multi-head attention followed by a position-wise feed-forward sublayer."""

    def __init__(self, d_model=768, d_k=64, d_v=64, h=12, d_ff=2048, dropout=.1, identity_map_reordering=False,
                 attention_module=None, attention_module_kwargs=None):
        super(EncoderLayer, self).__init__()
        self.identity_map_reordering = identity_map_reordering
        self.mhatt = MultiHeadAttention(d_model=d_model, d_k=d_k, d_v=d_v, h=h, dropout=dropout,
                                        identity_map_reordering=identity_map_reordering,
                                        attention_module=attention_module,
                                        attention_module_kwargs=attention_module_kwargs)
        self.pwff = PositionWiseFeedForward(d_model=d_model, d_ff=d_ff, dropout=dropout,
                                            identity_map_reordering=identity_map_reordering)

    def forward(self, queries, keys, values, attention_mask=None, attention_weights=None):
        attended = self.mhatt(queries, keys, values, attention_mask, attention_weights)
        return self.pwff(attended)


class MultiLevelEncoder(nn.Module):
    """Stack of ``N`` encoder layers that returns the output of every layer.

    ``forward`` returns ``(outs, attention_mask)``. ``outs`` stacks the
    per-layer outputs along dim 1. ``attention_mask`` is True where the
    feature vector of a position sums exactly to ``padding_idx``. That test
    is only meaningful if padded positions still hold their padding values in
    ``input``.
    """

    def __init__(self, N, padding_idx, d_model=768, d_k=64, d_v=64, h=12, d_ff=2048, dropout=.1,
                 identity_map_reordering=False, attention_module=None, attention_module_kwargs=None):
        super(MultiLevelEncoder, self).__init__()
        self.d_model = d_model
        self.dropout = dropout
        self.layers = nn.ModuleList(
            EncoderLayer(d_model, d_k, d_v, h, d_ff, dropout,
                         identity_map_reordering=identity_map_reordering,
                         attention_module=attention_module,
                         attention_module_kwargs=attention_module_kwargs)
            for _ in range(N)
        )
        self.padding_idx = padding_idx

    def forward(self, input, attention_weights=None):
        # input (b_s, seq_len, d_in)
        is_padding = torch.sum(input, -1) == self.padding_idx  # (b_s, seq_len)
        attention_mask = is_padding[:, None, None]  # (b_s, 1, 1, seq_len)

        per_layer = []
        hidden = input
        for layer in self.layers:
            hidden = layer(hidden, hidden, hidden, attention_mask, attention_weights)
            per_layer.append(hidden)
        return torch.stack(per_layer, dim=1), attention_mask

def gelu(x):
    """GELU activation, tanh approximation: 0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x**3)))."""
    cubic = x + 0.044715 * torch.pow(x, 3)
    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * cubic))


class VisualEncoder(MultiLevelEncoder):
    """Project region features to ``d_model`` and run the multi-level encoder.

    :param N: Number of encoder layers.
    :param padding_idx: Value that the per-region feature sum equals for padded
        regions (0 for zero-padded features).
    :param d_in: Dimensionality of the input region features.
    """

    def __init__(self, N, padding_idx, d_in=2048, **kwargs):
        super(VisualEncoder, self).__init__(N, padding_idx, **kwargs)
        self.fc = nn.Linear(d_in, self.d_model)
        self.dropout = nn.Dropout(p=self.dropout)
        self.layer_norm = nn.LayerNorm(self.d_model)

    def forward(self, input, attention_weights=None):
        """
        :param input: Region features (b_s, seq_len, d_in).
        :return: ``(outs, attention_mask)`` where ``outs`` is
            (b_s, N, seq_len, d_model) and ``attention_mask`` is
            (b_s, 1, 1, seq_len), True at padded regions.
        """
        # The padding mask must come from the raw features. After fc + GELU,
        # a zero-padded row becomes a non-zero vector derived from fc.bias.
        # After LayerNorm (at init) every row sums to ~0. Testing the projected
        # features, as MultiLevelEncoder.forward does, therefore misses padding
        # and can mask real regions whose sum happens to be exactly 0.
        # NOTE(review): a sample whose regions are ALL padding now has every key
        # masked; with a -inf mask fill in the attention that row becomes NaN.
        attention_mask = (torch.sum(input, -1) == self.padding_idx).unsqueeze(1).unsqueeze(1)

        out = gelu(self.fc(input))
        out = self.dropout(out)
        out = self.layer_norm(out)

        # Same layer loop as MultiLevelEncoder.forward, but with the mask above.
        outs = []
        for layer in self.layers:
            out = layer(out, out, out, attention_mask, attention_weights)
            outs.append(out.unsqueeze(1))
        return torch.cat(outs, 1), attention_mask

In [None]:
# Smoke test: 3 layers, padding_idx=0, 512-d inputs -> stacked outputs (b_s, N, seq_len, d_model).
encoder = VisualEncoder(3, 0, 512)
encoder(torch.randn(1, 1, 512))[0].shape

torch.Size([1, 3, 1, 768])

## `models/transformer/config.py` (Configs for GPT2)

In [None]:
'''
    code by TaeHwan Jung(@graykode)
    Original Paper and repository here : https://github.com/openai/gpt-2
    GPT2 Pytorch Model : https://github.com/huggingface/pytorch-pretrained-BERT
'''
class GPT2Config(object):
    """Hyper-parameters for the GPT-2 decoder.

    Defaults: vocabulary 50257, 1024 positions, context ``n_ctx`` 60,
    768-d embeddings, 12 layers, 12 heads, layer-norm epsilon 1e-5,
    initializer range 0.02, attention/residual dropout 0.1.

    Note: despite its name, ``vocab_size_or_config_json_file`` is stored
    unchanged as ``vocab_size``. Loading a JSON config file is not implemented.
    """

    def __init__(
        self,
        vocab_size_or_config_json_file=50257,
        n_positions=1024,
        n_ctx=60,
        n_embd=768,
        n_layer=12,
        n_head=12,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
        attn_pdrop=0.1,
        resid_pdrop=0.1,
    ):
        # Vocabulary and sequence lengths
        self.vocab_size = vocab_size_or_config_json_file
        self.n_ctx = n_ctx
        self.n_positions = n_positions
        # Model size
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        # Normalisation, initialisation and regularisation
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.attn_pdrop = attn_pdrop
        self.resid_pdrop = resid_pdrop

## `models/transformer/load_gptmodel.py`

In [None]:
'''
    code by TaeHwan Jung(@graykode)
    Original Paper and repository here : https://github.com/openai/gpt-2
    GPT2 Pytorch Model : https://github.com/huggingface/pytorch-pretrained-BERT
'''
import logging

logger = logging.getLogger(__name__)

def load_weight(model, state_dict):
    """Load a GPT-2 ``state_dict`` into ``model`` without failing on mismatches.

    TF-style key suffixes are renamed first: ``.g`` and ``.w`` become
    ``.weight``, and ``.b`` becomes ``.bias``. The module tree is then walked
    with ``_load_from_state_dict``. Parameters absent from ``state_dict`` keep
    their current values. Missing keys, unexpected keys and load errors (e.g.
    size mismatches) are logged rather than raised.

    If ``model`` has a ``transformer`` attribute and no key starts with
    ``transformer.``, the weights are loaded into ``model.transformer``.

    :param model: Module to load into (updated in place).
    :param state_dict: Mapping of parameter names to tensors. It is not
        modified; renaming happens on a shallow copy.
    :return: ``model``.
    """
    # Copy first (keeping torch's ``_metadata``) so the caller's dict is not
    # renamed in place.
    metadata = getattr(state_dict, "_metadata", None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    suffix_map = {".g": ".weight", ".b": ".bias", ".w": ".weight"}
    renames = [(key, key[:-2] + suffix_map[key[-2:]]) for key in state_dict if key[-2:] in suffix_map]
    for old_key, new_key in renames:
        state_dict[new_key] = state_dict.pop(old_key)

    missing_keys = []
    unexpected_keys = []
    error_msgs = []

    def load(module, prefix=""):
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        module._load_from_state_dict(
            state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs
        )
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + ".")

    start_model = model
    if hasattr(model, "transformer") and not any(s.startswith('transformer.') for s in state_dict):
        start_model = model.transformer
    load(start_model, prefix="")

    # Report what happened instead of silently discarding the collected lists.
    target = type(start_model).__name__
    if missing_keys:
        logger.info("Weights of %s not initialized from pretrained model: %s", target, missing_keys)
    if unexpected_keys:
        logger.info("Weights from pretrained model not used in %s: %s", target, unexpected_keys)
    if error_msgs:
        logger.warning("Errors while loading weights into %s:\n\t%s", target, "\n\t".join(error_msgs))

    # Make sure we are still sharing the output and input embeddings after loading weights
    # model.set_tied() <-- only needed if we are using GPT2MLMHead
    return model

## `models/transformer/gpt_decoder_visualGPT.py` (implements GPT2)

In [None]:
'''
    code by TaeHwan Jung(@graykode)
    Original Paper and repository here : https://github.com/openai/gpt-2
    GPT2 Pytorch Model : https://github.com/huggingface/pytorch-pretrained-BERT
'''
import copy
import torch
import math
import torch.nn as nn
from torch.nn.parameter import Parameter
import numpy as np
from torch.nn import functional as F

def gelu(x):
    """GELU activation (tanh approximation), used as the MLP non-linearity."""
    inner = math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))
    return 0.5 * x * (1 + torch.tanh(inner))

class LayerNorm(nn.Module):
    """Layer normalization over the last dimension, TF style (epsilon inside the square root).

    Uses the biased (population) variance, followed by a learned per-feature
    scale ``weight`` and shift ``bias``.
    """

    def __init__(self, hidden_size, eps=1e-12):
        super(LayerNorm, self).__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        centered = x - x.mean(-1, keepdim=True)
        variance = centered.pow(2).mean(-1, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.variance_epsilon)
        return self.weight * normalized + self.bias

class Conv1D(nn.Module):
    """GPT-2 style affine layer: ``y = x @ weight + bias`` over the last dimension.

    :param nf: Number of output features.
    :param nx: Number of input features.

    ``weight`` is stored input-major as ``(nx, nf)`` and is used without a
    transpose. load_weight copies checkpoint ``.w`` tensors into it as-is.
    """

    def __init__(self, nf, nx):
        super(Conv1D, self).__init__()
        self.nf = nf
        w = torch.empty(nx, nf)
        nn.init.normal_(w, std=0.02)
        self.weight = Parameter(w)
        self.bias = Parameter(torch.zeros(nf))

    def forward(self, x):
        size_out = x.size()[:-1] + (self.nf,)
        flat = x.contiguous().view(-1, x.size(-1))  # (N, nx)
        x = torch.addmm(self.bias, flat, self.weight)  # (N, nf)
        return x.view(*size_out)


class Attention(Module):
    """GPT-2 multi-head self-attention.

    A single ``c_attn`` projection produces queries, keys and values, which
    are split into ``config.n_head`` heads. The merged result is projected by
    ``c_proj``. When ``can_be_stateful`` is set and statefulness is enabled,
    keys and values from earlier calls are cached in ``running_keys`` /
    ``running_values``, so each call attends to every position seen so far.
    """

    def __init__(self, nx, n_ctx, config, scale=False, can_be_stateful=False):
        super(Attention, self).__init__()

        n_state = nx  # in Attention: n_state=768 (nx=n_embd)
        # [switch nx => n_state from Block to Attention to keep identical to TF implem]
        assert n_state % config.n_head == 0
        # Lower-triangular buffer. forward() does not read it (masking comes from
        # mask_self_attention), but it stays registered as part of the state_dict.
        self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
        self.n_head = config.n_head
        self.split_size = n_state
        self.scale = scale
        self.c_attn = Conv1D(n_state * 3, nx)
        self.c_proj = Conv1D(n_state, nx)
        self.can_be_stateful = can_be_stateful
        self.attn_pdrop = nn.Dropout(config.attn_pdrop)

        if self.can_be_stateful:
            # Per-head caches (n_head, 0, head_dim); enable_statefulness adds the
            # batch dimension. Derived from the config (was fixed at (12, 0, 64)).
            head_dim = n_state // self.n_head
            self.register_state('running_keys', torch.zeros((self.n_head, 0, head_dim)))
            self.register_state('running_values', torch.zeros((self.n_head, 0, head_dim)))

    def _attn(self, q, k, v, mask_self_attention):
        """Attention over already-split heads.

        :param q: (batch, head, n_q, head_dim)
        :param k: (batch, head, head_dim, n_k), already transposed
        :param v: (batch, head, n_k, head_dim)
        :param mask_self_attention: Boolean mask broadcastable to the scores,
            True where attention is blocked, or None.
        """
        w = torch.matmul(q, k)
        if self.scale:
            w = w / math.sqrt(v.size(-1))
        if mask_self_attention is not None:
            # Large finite negative (not -inf) so fully masked rows stay finite.
            w = w.masked_fill(mask_self_attention, -10000.0)
        w = nn.Softmax(dim=-1)(w)
        # Apply attention dropout to the weights that are actually used.
        # Previously the dropped-out copy was only stored in self.w.
        w = self.attn_pdrop(w)
        self.w = w
        return torch.matmul(w, v)

    def merge_heads(self, x):
        x = x.permute(0, 2, 1, 3).contiguous()
        new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
        return x.view(*new_x_shape)  # in Tensorflow implem: fct merge_states

    def split_heads(self, x, k=False):
        new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
        x = x.view(*new_x_shape)  # in Tensorflow implem: fct split_states
        if k:
            return x.permute(0, 2, 3, 1)  # (batch, head, head_features, seq_length)
        else:
            return x.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)

    def forward(self, x, layer_past=None, mask_self_attention=None):
        """Self-attention over ``x``; returns ``(output, present)``.

        ``present`` stacks the keys (transposed back) and values used in this
        call. ``layer_past`` is accepted but not used.
        """
        x = self.c_attn(x)
        query, key, value = x.split(self.split_size, dim=2)
        query = self.split_heads(query)
        key = self.split_heads(key, k=True)
        value = self.split_heads(value)

        if self.can_be_stateful and self._is_stateful:
            # Keys are cached as (..., seq, head_dim), appended along seq, then
            # transposed back to the (..., head_dim, seq) layout _attn expects.
            self.running_keys = torch.cat([self.running_keys, key.transpose(-2, -1)], -2)
            key = self.running_keys.transpose(-2, -1)

            self.running_values = torch.cat([self.running_values, value], -2)
            value = self.running_values

        present = torch.stack((key.transpose(-2, -1), value))  # transpose to have same shapes for stacking
        a = self._attn(query, key, value, mask_self_attention)
        a = self.merge_heads(a)
        a = self.c_proj(a)
        return a, present


class Enc_Dec_Attention(Module):
    """Cross-attention from decoder states to one level of encoder output.

    Queries come from ``fc_q(x)``, and keys/values from ``fc_k``/``fc_v`` of
    the encoder output. Results are merged across heads and projected by
    ``c_proj``.
    """

    def __init__(self, nx, n_ctx, config, scale=False):
        super(Enc_Dec_Attention, self).__init__()
        # NOTE(review): the arguments nx, n_ctx, scale and config are
        # overridden or ignored. The layer is fixed at 768 features, 12 heads
        # of 64 dims, n_ctx 60, scaled scores and attention dropout 0.2.
        # Confirm before using a different configuration.
        n_state = nx = 768
        n_ctx = 60
        scale = True
        assert n_state % 12 == 0
        self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
        self.n_head = 12
        self.split_size = n_state
        self.scale = scale
        # c_attn is not used by forward(); it remains registered in the state_dict.
        self.c_attn = Conv1D(n_state * 3, nx)
        self.c_proj = Conv1D(n_state, nx)

        self.fc_q = nn.Linear(n_state, 64 * 12)
        self.fc_k = nn.Linear(n_state, 64 * 12)
        self.fc_v = nn.Linear(n_state, 64 * 12)

        self.attn_dropout = nn.Dropout(0.2)

        self.init_weights()

    def init_weights(self):
        """Xavier-uniform weights and zero biases for the q/k/v projections."""
        nn.init.xavier_uniform_(self.fc_q.weight)
        nn.init.xavier_uniform_(self.fc_k.weight)
        nn.init.xavier_uniform_(self.fc_v.weight)

        nn.init.constant_(self.fc_q.bias, 0)
        nn.init.constant_(self.fc_k.bias, 0)
        nn.init.constant_(self.fc_v.bias, 0)

    def _attn(self, q, k, v, enc_dec_attention):
        """Attention over split heads; ``k`` is (batch, head, head_dim, n_k).

        ``enc_dec_attention`` is a boolean mask (True = blocked) or None.
        """
        w = torch.matmul(q, k)
        if self.scale:
            w = w / math.sqrt(v.size(-1))
        if enc_dec_attention is not None:
            w = w.masked_fill(enc_dec_attention, -10000.0)
        w = nn.Softmax(dim=-1)(w)
        w = self.attn_dropout(w)
        return torch.matmul(w, v)

    def merge_heads(self, x):
        x = x.permute(0, 2, 1, 3).contiguous()
        new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
        return x.view(*new_x_shape)  # in Tensorflow implem: fct merge_states

    def split_heads(self, x, k=False):
        new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
        x = x.view(*new_x_shape)  # in Tensorflow implem: fct split_states
        if k:
            return x.permute(0, 2, 3, 1)  # (batch, head, head_features, seq_length)
        else:
            return x.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)

    def forward(self, x, layer_past=None, encoder_output=None, mask_encoder=None):
        """Attend from ``x`` to ``encoder_output``; ``layer_past`` is not used."""
        query = self.split_heads(self.fc_q(x))
        encoder_key = self.split_heads(self.fc_k(encoder_output), k=True)
        encoder_value = self.split_heads(self.fc_v(encoder_output))

        a = self._attn(query, encoder_key, encoder_value, mask_encoder)
        a = self.merge_heads(a)
        a = self.c_proj(a)
        return a


class MLP(nn.Module):
    """GPT-2 feed-forward sublayer: ``c_proj(gelu(c_fc(x)))``.

    :param n_state: Hidden width.
    :param config: Supplies ``n_embd``, the input and output width.
    """

    def __init__(self, n_state, config):  # in MLP: n_state=3072 (4 * n_embd)
        super(MLP, self).__init__()
        width = config.n_embd
        self.c_fc = Conv1D(n_state, width)
        self.c_proj = Conv1D(width, n_state)
        self.act = gelu

    def forward(self, x):
        return self.c_proj(self.act(self.c_fc(x)))

class Block(Module):
    """One VisualGPT decoder layer.

    Steps:
    1. Self-attention over the tokens, with the residual stream ``a``.
    2. Cross-attention with the shared ``enc_dec_attn`` against each of
       ``encoder_output[:, 0]``, ``[:, 1]`` and ``[:, 2]``.
    3. Per level, a sigmoid gate ``alpha`` (from ``[a, enc_att]``) weights
       the linguistic stream ``a`` by ``alpha`` and the visual stream by
       ``1 - alpha``. The linguistic term is kept only where ``alpha > tau``,
       and the visual term only where ``alpha < 1 - tau``.
    4. The three gated results are summed and divided by ``sqrt(3)``, then
       passed through the MLP with a residual connection.

    Results are multiplied by ``mask_queries`` (presumably 0 at padded
    positions).
    """

    def __init__(self, n_ctx, config, scale=False):
        super(Block, self).__init__()
        nx = config.n_embd

        self.ln_1 = LayerNorm(nx, eps=config.layer_norm_epsilon)
        self.attn = Attention(nx, n_ctx, config, scale, can_be_stateful=True)
        self.enc_dec_attn = Enc_Dec_Attention(nx, n_ctx, config, scale)
        self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon)
        self.mlp = MLP(4 * nx, config)
        self.resid_pdrop = nn.Dropout(config.resid_pdrop)

        # One gate per encoder level, each mapping [a, enc_att] -> nx.
        self.fc_alpha1 = nn.Linear(nx + nx, nx)
        self.fc_alpha2 = nn.Linear(nx + nx, nx)
        self.fc_alpha3 = nn.Linear(nx + nx, nx)

    def forward(self, x, layer_past=None, mask_queries=None, encoder_output=None, mask_encoder=None,
                mask_self_attention=None, tau=0):
        self_attention, present = self.attn(self.ln_1(x), layer_past=layer_past,
                                            mask_self_attention=mask_self_attention)
        a = self.resid_pdrop(x + self_attention)

        # Cross-attend to each encoder level (computed before any gating).
        enc_atts = [
            self.enc_dec_attn(x=self.ln_1(a), encoder_output=self.ln_1(encoder_output[:, level]),
                              mask_encoder=mask_encoder)
            for level in range(3)
        ]
        gate_layers = (self.fc_alpha1, self.fc_alpha2, self.fc_alpha3)

        gated = []
        for fc_alpha, enc_att in zip(gate_layers, enc_atts):
            alpha = torch.sigmoid(fc_alpha(torch.cat([a, enc_att], -1)))
            keep_linguistic = torch.where(alpha > tau, torch.ones_like(alpha), torch.zeros_like(alpha))
            keep_visual = torch.where(alpha < 1 - tau, torch.ones_like(alpha), torch.zeros_like(alpha))
            gated.append(alpha * keep_linguistic * a + (1 - alpha) * keep_visual * enc_att)

        fused = (gated[0] + gated[1] + gated[2]) / np.sqrt(3)
        a = fused * mask_queries

        encoder_result = self.resid_pdrop(a + self.mlp(self.ln_2(a)))
        return encoder_result * mask_queries, present

class GPT2Model(Module):
    """GPT-2 decoder stack whose blocks also receive encoder output.

    Sums token, position and (optionally) token-type embeddings, runs the
    result through ``config.n_layer`` copies of ``Block`` and applies a final
    LayerNorm.
    """

    def __init__(self, config):
        super().__init__()
        self.n_layer = config.n_layer
        self.n_embd = config.n_embd
        self.n_vocab = config.vocab_size

        self.wte = nn.Embedding(config.vocab_size, config.n_embd)  # token table (also used for token types)
        self.wpe = nn.Embedding(config.n_positions, config.n_embd)  # learned absolute positions
        # One block is built and deep-copied, so every layer starts from identical weights.
        template = Block(config.n_ctx, config, scale=True)
        self.h = ModuleList([copy.deepcopy(template) for _layer in range(config.n_layer)])
        self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        # NOTE(review): registered as state but not read by this class's forward.
        self.register_state('running_seq', torch.zeros((1,)).long())

    def set_embeddings_weights(self, model_embeddings_weights):
        """Attach a bias-free Linear whose weight *is* ``model_embeddings_weights`` (tied, not copied)."""
        n_out = model_embeddings_weights.shape[0]
        n_in = model_embeddings_weights.shape[1]
        self.decoder = nn.Linear(n_in, n_out, bias=False)
        self.decoder.weight = model_embeddings_weights  # Tied weights

    def forward(self, input_ids, position_ids=None, token_type_ids=None, past=None,
                mask_queries=None, encoder_output=None, mask_encoder=None,
                mask_self_attention=None, tau=0):
        """Embed ``input_ids`` and pass them through every block.

        :param past: per-layer caches or None; ``past[0][0].size(-2)`` is taken
            as the number of positions already processed
        :param position_ids: defaults to consecutive positions continuing after ``past``
        :param token_type_ids: optional; embedded with the token table ``wte``
        :return: ``(hidden_states, presents)``; hidden states are shaped
            ``input_ids.shape + (hidden_size,)`` and there is one ``present`` per block
        """
        if past is not None:
            past_length = past[0][0].size(-2)
        else:
            past_length = 0
            past = [None] * len(self.h)

        if position_ids is None:
            start = past_length
            stop = past_length + input_ids.size(-1)
            position_ids = torch.arange(start, stop, dtype=torch.long, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)

        input_shape = input_ids.size()
        flat_tokens = input_ids.view(-1, input_ids.size(-1))
        flat_positions = position_ids.view(-1, position_ids.size(-1))

        hidden_states = self.wte(flat_tokens) + self.wpe(flat_positions)
        if token_type_ids is not None:
            flat_types = token_type_ids.view(-1, token_type_ids.size(-1))
            hidden_states = hidden_states + self.wte(flat_types)

        presents = []
        for block, layer_past in zip(self.h, past):
            hidden_states, present = block(
                hidden_states,
                layer_past,
                mask_queries=mask_queries,
                encoder_output=encoder_output,
                mask_encoder=mask_encoder,
                mask_self_attention=mask_self_attention,
                tau=tau,
            )
            presents.append(present)

        hidden_states = self.ln_f(hidden_states)
        return hidden_states.view(*input_shape, hidden_states.size(-1)), presents

class GPT2LMHead(Module):
    """Vocabulary projection whose weight is shared with the token embedding matrix."""

    def __init__(self, model_embeddings_weights, config):
        super().__init__()
        self.n_embd = config.n_embd
        self.set_embeddings_weights(model_embeddings_weights)

    def set_embeddings_weights(self, model_embeddings_weights):
        """(Re)build ``self.decoder`` as a bias-free Linear tied to ``model_embeddings_weights``."""
        n_out = model_embeddings_weights.shape[0]
        n_in = model_embeddings_weights.shape[1]
        self.decoder = nn.Linear(n_in, n_out, bias=False)
        # Assign the very same Parameter object so the head and the embedding stay tied.
        self.decoder.weight = model_embeddings_weights

    def forward(self, hidden_state):
        """Return unnormalised vocabulary logits for ``hidden_state``."""
        return self.decoder(hidden_state)



class GPT2LMHeadModel(Module):
    """``GPT2Model`` plus a weight-tied language-modelling head.

    :param config: GPT-2 configuration passed to ``GPT2Model``
    :param padding_idx: token id treated as padding when building masks
    :param tau: value forwarded as ``tau`` to every decoder block
    """

    def __init__(self, config, padding_idx=47932, tau=0):
        super().__init__()
        self.transformer = GPT2Model(config)
        # The output projection shares the token embedding matrix.
        self.lm_head = GPT2LMHead(self.transformer.wte.weight, config)
        self.padding_idx = padding_idx

        # Grows along its last dim on every forward() call while the module is stateful.
        self.register_state('running_mask_self_attention', torch.zeros((1, 1, 0)).bool())
        self.tau = tau

    def set_tied(self):
        """Re-tie the LM head to the current token embedding matrix."""
        self.lm_head.set_embeddings_weights(self.transformer.wte.weight)

    def forward(self, input_ids, encoder_output=None, mask_encoder=None, position_ids=None,
                token_type_ids=None, lm_labels=None, past=None):
        """Run the decoder and the LM head.

        :return: the cross-entropy loss (``ignore_index=-1``) when ``lm_labels``
            is given, otherwise ``(log-probabilities over the vocabulary, presents)``
        """
        _, seq_len = input_ids.shape[:2]
        is_pad = input_ids == self.padding_idx
        mask_queries = (~is_pad).unsqueeze(-1).float()

        # True above the diagonal (future positions) and at padded key positions.
        future = torch.ones((seq_len, seq_len), dtype=torch.uint8, device=input_ids.device).triu(diagonal=1)
        combined = future[None, None] + is_pad[:, None, None].bool()
        mask_self_attention = combined.gt(0)  # (b_s, 1, seq_len, seq_len)
        if self._is_stateful:
            self.running_mask_self_attention = torch.cat(
                [self.running_mask_self_attention, mask_self_attention], -1)
            mask_self_attention = self.running_mask_self_attention

        hidden_states, presents = self.transformer(
            input_ids, position_ids, token_type_ids, past,
            mask_queries=mask_queries,
            encoder_output=encoder_output,
            mask_encoder=mask_encoder,
            mask_self_attention=mask_self_attention,
            tau=self.tau,
        )
        lm_logits = self.lm_head(hidden_states)

        if lm_labels is None:
            return F.log_softmax(lm_logits, dim=-1), presents
        criterion = nn.CrossEntropyLoss(ignore_index=-1)
        return criterion(lm_logits.view(-1, lm_logits.size(-1)), lm_labels.view(-1))

# Our Model

In [None]:
import torch 
import torch.nn as nn
import copy

# Pretrained GPT-2 weights downloaded at the top of the notebook; kept on CPU when no GPU exists.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.
state_dict = torch.load(
    'gpt2-pytorch_model.bin',
    map_location=None if torch.cuda.is_available() else 'cpu',
)

class Transformer_visualgpt(Module):
  """Image-conditioned GPT-2: a visual encoder feeding a pretrained ``GPT2Model`` decoder.

  :param bos_idx: beginning-of-sequence token id (stored; not used by ``forward``)
  :param encoder: module called as ``encoder(images)`` returning ``(enc_output, mask_enc)``
  :param padding_idx: token id treated as padding when building the masks
  :param n_layer: number of GPT-2 blocks to build before loading the weights
  :param tau: forwarded as ``tau`` to every decoder block
  """

  def __init__(self, bos_idx, encoder, padding_idx=47932, n_layer=12, tau=0):
    super(Transformer_visualgpt, self).__init__()
    self.bos_idx = bos_idx
    self.encoder = encoder
    config = GPT2Config()
    config.n_layer = n_layer
    # Weights come from the module-level ``state_dict`` loaded from the GPT-2 checkpoint.
    decoder = GPT2Model(config)
    decoder = load_weight(decoder, state_dict)
    self.decoder = decoder

    self.padding_idx = padding_idx
    self.tau = tau

    self.register_state('enc_output', None)
    self.register_state('mask_enc', None)
    # forward() appends to this in stateful mode, so it must exist as state
    # (initialised the same way as in GPT2LMHeadModel).
    self.register_state('running_mask_self_attention', torch.zeros((1, 1, 0)).bool())
    self.init_weights()

  def init_weights(self):
    """Xavier-initialise every matrix-shaped encoder parameter; the decoder keeps its GPT-2 weights."""
    for p in self.encoder.parameters():
      if p.dim() > 1:
        nn.init.xavier_uniform_(p)

  def forward(self, input_ids, images, *args):
    """Encode ``images`` and run the GPT-2 decoder over ``input_ids``.

    Extra positional ``*args`` are accepted and ignored.

    :return: ``(hidden_states, presents)`` from the decoder
    """
    enc_output, mask_enc = self.encoder(images)

    _, seq_len = input_ids.shape[:2]
    mask_queries = (input_ids != self.padding_idx).unsqueeze(-1).float()
    # True above the diagonal (future positions) and at padded key positions.
    mask_self_attention = torch.triu(torch.ones((seq_len, seq_len), dtype=torch.uint8, device=input_ids.device), diagonal=1)
    mask_self_attention = mask_self_attention.unsqueeze(0).unsqueeze(0)  # (1, 1, seq_len, seq_len)
    mask_self_attention = mask_self_attention + (input_ids == self.padding_idx).unsqueeze(1).unsqueeze(1).bool()
    mask_self_attention = mask_self_attention.gt(0)  # (b_s, 1, seq_len, seq_len)
    if self._is_stateful:
      self.running_mask_self_attention = torch.cat([self.running_mask_self_attention, mask_self_attention], -1)
      mask_self_attention = self.running_mask_self_attention

    # The decoder lives in self.decoder (there is no self.transformer), and the
    # encoder results are enc_output / mask_enc.
    hidden_states, presents = self.decoder(
        input_ids,
        mask_queries=mask_queries,
        encoder_output=enc_output,
        mask_encoder=mask_enc,
        mask_self_attention=mask_self_attention,
        tau=self.tau,
    )

    return hidden_states, presents


    

In [None]:
# VisualEncoder(3, 0, ...) — presumably 3 layers with padding index 0; confirm against its signature.
encoder = VisualEncoder(3, 0, attention_module=ScaledDotProductAttention)
model = Transformer_visualgpt(bos_idx=0, encoder=encoder)

In [None]:
# Count only the parameters that will receive gradients.
pytorch_total_params = sum(
    p.numel() for p in filter(lambda p: p.requires_grad, model.parameters())
)
print("Total number of trainable parameters:", pytorch_total_params)

Total number of trainable parameters: 239976960


## Setup tokenizer

In [None]:
!git clone https://github.com/openai/CLIP.git
!pip install -e CLIP/

fatal: destination path 'CLIP' already exists and is not an empty directory.
Obtaining file:///content/CLIP
Collecting ftfy
[?25l  Downloading https://files.pythonhosted.org/packages/ce/b5/5da463f9c7823e0e575e9908d004e2af4b36efa8d02d3d6dad57094fcb11/ftfy-6.0.1.tar.gz (63kB)
[K     |████████████████████████████████| 71kB 6.3MB/s 
Collecting torch~=1.7.1
[?25l  Downloading https://files.pythonhosted.org/packages/90/5d/095ddddc91c8a769a68c791c019c5793f9c4456a688ddd235d6670924ecb/torch-1.7.1-cp37-cp37m-manylinux1_x86_64.whl (776.8MB)
[K     |████████████████████████████████| 776.8MB 22kB/s 
[?25hCollecting torchvision~=0.8.2
[?25l  Downloading https://files.pythonhosted.org/packages/94/df/969e69a94cff1c8911acb0688117f95e1915becc1e01c73e7960a2c76ec8/torchvision-0.8.2-cp37-cp37m-manylinux1_x86_64.whl (12.8MB)
[K     |████████████████████████████████| 12.8MB 42.5MB/s 
Building wheels for collected packages: ftfy
  Building wheel for ftfy (setup.py) ... [?25l[?25hdone
  Created wheel 

In [None]:
!pip install transformers
from transformers import GPT2Tokenizer
# GPT-2 byte-level BPE tokenizer; the call below is a quick sanity check of its output.
tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model_name_or_path="gpt2")
tokenizer(text="Hello world")

Successfully installed sacremoses-0.0.45 tokenizers-0.10.2 transformers-4.5.1


HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1042301.0, style=ProgressStyle(descript…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=456318.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1355256.0, style=ProgressStyle(descript…




{'input_ids': [15496, 995], 'attention_mask': [1, 1]}

In [None]:
"""
Most of the source code is taken from
https://www.drivendata.co/blog/hateful-memes-benchmark/
"""
from torch.utils.data import Dataset
import pandas as pd
import torch

from PIL import Image


class HatefulMemesDataset(Dataset):
    """Hateful Memes samples indexed by a jsonl file.

    Items are ``(image, text, label)`` when the jsonl has a ``label`` column
    (train/dev splits) and ``(image, text)`` otherwise (unlabeled inference split).
    """

    def __init__(self, jsonl_path, path_to_img_dir,
                 image_transform, text_transform):
        """
        :param jsonl_path: path to jsonl provided by Facebook (e.g. data/train.jsonl)
        :param path_to_img_dir: path to parent directory of img dir
        :param image_transform: callable applied to each RGB PIL image (e.g. torchvision.transforms.Compose)
        :param text_transform: callable applied to each raw text string
        """
        samples_frame = pd.read_json(jsonl_path, lines=True).reset_index(drop=True)
        # Prefix every "img" entry with the image directory in one vectorised
        # operation instead of a row-by-row DataFrame.apply.
        samples_frame["img"] = path_to_img_dir + '/' + samples_frame["img"]
        self.samples_frame = samples_frame
        self.image_transform = image_transform
        self.text_transform = text_transform

    def __len__(self):
        return len(self.samples_frame)

    def __getitem__(self, idx):
        """Load, transform and return sample ``idx``."""
        if torch.is_tensor(idx):
            idx = idx.tolist()
        # The context manager releases the file handle; convert() returns an
        # independent, fully loaded image, so it is safe to use after closing.
        with Image.open(self.samples_frame.loc[idx, "img"]) as raw_image:
            image = raw_image.convert("RGB")
        image = self.image_transform(image)

        # TODO: Find a better way for reducing length of a sentence.
        # this is an actual sentence from the dataset:
        #
        # we only want to make you register them, restrict transfers,
        # ban certain guns, limit magazine capacity, prohibit carrying them,
        # ban or limit ammo, make other arbitrary laws, and, if we catch you
        # violating any of these made-up rules, throw you in prison.... at
        # which point we will take your guns!
        text = self.text_transform(self.samples_frame.loc[idx, "text"])

        if "label" not in self.samples_frame.columns:
            # case: inference
            return image, text
        # case: development — a 0-dim int64 tensor, built without a float round trip.
        label = torch.tensor(self.samples_frame.loc[idx, "label"], dtype=torch.long)
        return image, text, label

In [None]:
"""
Simple training loop
"""
import math
import logging
from itertools import chain

from tqdm.notebook import tqdm
import numpy as np

import torch

from torch.utils.data import DistributedSampler, DataLoader

from torch.cuda.amp import autocast
from transformers import AdamW

class Trainer:
    """Minimal train/validate loop over the Hateful Memes jsonl splits.

    Each epoch trains on the train split, evaluates on the validation split
    (when one is given), and saves a checkpoint whenever validation loss
    improves, or every epoch when there is no validation split.
    """

    def __init__(self, model, loss_f, image_preprocess, text_preprocess, h, ckpt_path):
        """
        :param model: torch.nn.Module called as ``model(text=..., image=...)``; for
            accuracy its output is assumed to be (batch, classes) — TODO confirm
        :param loss_f: (model's output, target) -> a real number wrapped by torch.Tensor
        :param image_preprocess: transform applied to every PIL image by the dataset
        :param text_preprocess: transform applied to every raw text by the dataset
        :param h: dictionary of hyper-parameters; reads "batch_size", "num_workers",
            "max_epochs" and optionally "learning_rate" (default 1e-4)
        :param ckpt_path: file to write the best state_dict to, or None to disable saving
        """
        # Use CUDA when present, otherwise fall back to CPU instead of crashing in .cuda().
        self.device = 'cpu'
        if torch.cuda.is_available():
            self.device = torch.cuda.current_device()
        self.model = model.to(self.device)
        self.loss_f = loss_f
        self.h = h
        self.image_preprocess = image_preprocess
        self.text_preprocess = text_preprocess
        self.ckpt_path = ckpt_path

    def save_checkpoint(self):
        """Write the (unwrapped) model's state_dict to ``self.ckpt_path``."""
        raw_model = self.model
        # DataParallel / DistributedDataParallel keep the real model in .module;
        # saving that avoids a "module." prefix on every key.
        if isinstance(raw_model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)):
            raw_model = raw_model.module
        print("saving", self.ckpt_path)
        torch.save(raw_model.state_dict(), self.ckpt_path)

    def train(self, trainset_jsonl, trainset_image_dir_path, valset_jsonl, valset_image_dir_path):
        """Train for ``h["max_epochs"]`` epochs; validate each epoch unless ``valset_jsonl`` is None."""
        model, loss_f, h = self.model, self.loss_f, self.h
        optimizer = AdamW(model.parameters(), lr=h.get("learning_rate", 1e-4), betas=(0.9, 0.999), eps=1e-8)
        # Mixed precision only on CUDA; the scaler keeps fp16 gradients from underflowing.
        use_amp = torch.cuda.is_available()
        scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

        train_dataset = HatefulMemesDataset(
            trainset_jsonl,
            trainset_image_dir_path,
            image_transform=self.image_preprocess,
            text_transform=self.text_preprocess)

        val_dataset = None
        if valset_jsonl is not None:
            val_dataset = HatefulMemesDataset(
                valset_jsonl,
                valset_image_dir_path,
                image_transform=self.image_preprocess,
                text_transform=self.text_preprocess
            )

        def run_epoch(split, epoch):
            """One pass over ``split``; returns the mean loss for the validation split."""
            is_train = split == 'train'
            # Switch the whole model between train and eval mode. Previously this
            # always called .train() on VisualBERT-only submodules, so validation
            # ran with dropout and models without those attributes crashed.
            model.train(is_train)

            data = train_dataset if is_train else val_dataset
            loader = DataLoader(data, shuffle=is_train, pin_memory=True,
                                batch_size=h["batch_size"],
                                num_workers=h["num_workers"])

            losses = []
            num_correct_pred = 0
            num_pred = 0
            pbar = tqdm(enumerate(loader), total=len(loader), position=0, leave=True) if is_train else enumerate(loader)
            for it, (image, text, label) in pbar:
                # NOTE(review): GPT2Tokenizer has no pad token by default, so padding=True
                # may raise unless one is set (e.g. tokenizer.pad_token = tokenizer.eos_token).
                text = tokenizer(text, return_tensors="pt", padding=True)
                # place data on the correct device
                text = text.to(self.device)
                image = image.to(self.device)
                label = label.long().to(self.device)

                with torch.set_grad_enabled(is_train):
                    with autocast(enabled=use_amp):
                        output = model(text=text, image=image)
                        loss = loss_f(output, label)
                losses.append(loss.item())

                if is_train:
                    optimizer.zero_grad()
                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()
                    pbar.set_description(f"epoch {epoch + 1} iter {it}: train loss {loss.item():.5f}")
                else:
                    prediction = torch.argmax(output, dim=1)
                    num_correct_pred += (prediction == label).sum().item()
                    num_pred += prediction.shape[0]

            if not is_train:
                test_loss = float(np.mean(losses))
                acc = num_correct_pred / num_pred if num_pred else 0.0
                print("test loss:", test_loss)
                print("test acc:", acc)
                return test_loss
            return None

        best_loss = float('inf')
        for epoch in range(h["max_epochs"]):
            run_epoch('train', epoch)
            test_loss = run_epoch('test', epoch) if val_dataset is not None else None

            # With a validation split, save only on improvement; without one, save every epoch.
            good_model = test_loss is None or test_loss < best_loss
            if self.ckpt_path is not None and good_model:
                if test_loss is not None:
                    best_loss = test_loss
                self.save_checkpoint()

In [None]:
# Thank you Mario 🙏
# Mount Google Drive (Colab only) so the zipped dataset can be copied into /content below.
from google.colab import drive
drive.mount('/gdrive')

!cp '/gdrive/MyDrive/MemesDeepLearning/dataFB.zip' '/content/data.zip'
!unzip -q data.zip

Mounted at /gdrive


In [None]:
import torch.nn.functional as F 


# Training configuration for the VisualGPT model built above
# (the former `model = model` line was a no-op and has been dropped).
loss_f = nn.CrossEntropyLoss()
image_preprocess = clip_image_preprocess
# Raw text strings pass through the dataset unchanged; Trainer tokenizes each batch itself.
text_preprocess = lambda x: x
h = {
  "batch_size": 25,
  "num_workers": 2,
  "max_epochs": 32,
}


trainer = Trainer(model, loss_f, image_preprocess, text_preprocess, h, "experiment_0.pt")

RuntimeError: ignored

In [None]:
# Shape of the embedding for a single random 3-channel 224x224 input.
clip.encode_image(torch.randn((1, 3, 224, 224))).shape

torch.Size([1, 512])

In [None]:
# Encode a short probe sentence; the result is dict-like with 'input_ids' and 'attention_mask'.
input_ids = tokenizer(text="Hello world")

In [None]:
# Token ids are integers: build a (1, seq_len) long tensor directly instead of a float one.
input_ids = torch.tensor(input_ids['input_ids'], dtype=torch.long).unsqueeze(0)
input_ids

tensor([[15496.,   995.]])

In [None]:
# Rebuild the masks GPT2LMHeadModel.forward computes, for inspection.
# In the GPT-2 vocabulary id 0 is an ordinary token ("!"), so treat the models'
# padding_idx default (47932) as padding instead of 0.
padding_idx = 47932
b_s, seq_len = input_ids.shape[:2]
mask_queries = (input_ids != padding_idx).unsqueeze(-1).float()
mask_self_attention = torch.triu(torch.ones((seq_len, seq_len), dtype=torch.uint8, device=input_ids.device), diagonal=1)
mask_self_attention = mask_self_attention.unsqueeze(0).unsqueeze(0)  # (1, 1, seq_len, seq_len)
mask_self_attention = mask_self_attention + (input_ids == padding_idx).unsqueeze(1).unsqueeze(1).bool()
mask_self_attention = mask_self_attention.gt(0)  # (b_s, 1, seq_len, seq_len)


NameError: ignored

In [None]:
config = GPT2Config()
# NOTE(review): this call raised RuntimeError when run. mask_queries and
# mask_self_attention are left at None even though the decoder block multiplies
# its output by mask_queries. The block also indexes encoder_output[:, 0], [:, 1]
# and [:, 2], which a 2-D (1, 512) tensor likely does not satisfy — confirm the
# expected encoder_output shape.
GPT2Model(config)(
    input_ids.type(torch.LongTensor),
    encoder_output=torch.randn((1, 512)),
)

RuntimeError: ignored