In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pickle
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import skimage.io as io
import PIL.Image

In [11]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cpu')

In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class FeedForward(nn.Module):
    '''Feed Forward Network for Transformer'''

    def __init__(self, d_model, d_ff, dropout_ratio = 0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout_ratio)

    def forward(self, x):
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        x = self.dropout(x)
        return x


class MultiHeadSA(nn.Module):
    '''Multi Head Self Attention Layer'''

    def __init__(self, n_heads, d_model, input_dim):  
        super().__init__()    
        assert d_model % n_heads == 0 , "Invalid head_size for the given d_model"
        self.n_heads = n_heads
        self.d_model = d_model
        self.head_size = d_model // n_heads
        self.input_dim = input_dim
        self.qkv_proj = nn.Linear(input_dim, 3 * d_model)
        self.linear = nn.Linear(d_model, d_model)
    
    def forward(self, X, mask = None):

        B, T, C = X.shape
        assert C == self.input_dim, "Input dimension does not match the model input dimension"
        qkv = self.qkv_proj(X)                                    # (B,T,3*D)
        qkv = qkv.reshape(B, T, self.n_heads, 3 * self.d_model // self.n_heads)
        qkv = qkv.permute(0,2,1,3)
        q, k, v = torch.chunk(qkv, 3, dim=-1)
        if mask is None:
            attention_score = torch.softmax(q @ k.transpose(-2, -1) / (self.head_size ** 0.5), dim=-1)
        else:
            mask = mask.unsqueeze(1)  # for broadcasting
            attention_score = torch.softmax(q @ k.transpose(-2, -1) / (self.head_size ** 0.5) + mask, dim=-1)
        res = attention_score @ v                                       # (B,H,T,head_size)
        res = res.permute(0,2,1,3).reshape(B, T, self.d_model)   
        res = self.linear(res)

        return res               

class EncoderLayer(nn.Module):
    '''Single Layer of Transformer Encoder'''
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.multi_head_sa = MultiHeadSA(self.config.n_heads, self.config.d_model, self.config.d_model)
        self.feed_forward = FeedForward(self.config.d_model, self.config.d_ff, self.config.dropout)
        self.norm1 = nn.LayerNorm(self.config.d_model)
        self.norm2 = nn.LayerNorm(self.config.d_model)
        self.dropout = nn.Dropout(self.config.dropout)

    def forward(self, x, mask = None):
        # ordering of layernorm is like the GPT2 paper and not like the original transformer paper
        # layer norm before attention and feed forward
        res = self.norm1(x)
        res = self.multi_head_sa(res, mask)
        res = x + self.dropout(res)          # residual connection and dropout
        res = self.norm2(res)
        res2 = self.feed_forward(res)
        res = res + self.dropout(res2)

        return res

class Transformer(nn.Module):
    '''Modified Encoder Only Representation of Transformer
        for converting CLIP embedding to GPT2 input'''
    
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.encoder_layers = nn.ModuleList([EncoderLayer(config) for _ in range(self.config.n_layers)])

    def forward(self, x, mask = None):
        for layer in self.encoder_layers:
            x = layer(x, mask)
        
        return x
    

class Mapping_Network(nn.Module):

    def __init__(self, config):

        super().__init__()
        self.config = config
        self.linear = nn.Linear(config.n_clip_emb, config.clip_length * config.d_model)        # 512 -> clip_length * d_model(768)
        self.fixed_prefix = nn.Parameter(torch.randn(config.prefix_length, config.d_model), requires_grad=True)            # fixed prefix
        self.transformer = Transformer(config)


    def forward(self, x):
        # x: (batch_size, n_clip_emb)
        res = self.linear(x)          # (batch_size, clip_length * d_model)
        res = res.view(res.shape[0], self.config.clip_length, self.config.d_model)        # (batch_size, clip_length, d_model)
        prefix = self.fixed_prefix.unsqueeze(0)             # adding batch dimension
        prefix = prefix.repeat(res.shape[0], 1, 1)          # (batch_size, prefix_length, d_model)
        # first clip_embedding followed by fixed prefix
        res = torch.cat((res, prefix), dim=1)               # (batch_size, prefix_length + clip_length, d_model)
        res = self.transformer(res)                         
        
        return res[:,self.config.clip_length:]       


class CaptionModel(nn.Module):

    def __init__(self,config):
        super().__init__()
        self.gpt = GPT2LMHeadModel.from_pretrained('gpt2')
        self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
        self.config = config
        self.clip_embedding_mapping = Mapping_Network(config)

    def forward(self, tokens, prefix, mask):
        cap_emb = self.gpt.transformer.wte(tokens)          # (batch_size, seq_len, embedding_size)
        clip_emb = self.clip_embedding_mapping(prefix).view(-1,self.config.prefix_length,self.gpt_embedding_size)      # (batch_size, prefix_length, d_model_gpt2)
        res = torch.cat((clip_emb, cap_emb), dim=1)
        res = self.gpt(inputs_embeds = res, attention_mask = mask)

        return res
    
    def train(self, mode = True):
        super(CaptionModel, self).train(mode)
        # freeze and train
        self.gpt.eval()                      # gpt2 weights remain fixed
        return self

In [13]:
from transformers import CLIPProcessor, CLIPModel

In [14]:
class Predictor():
    
    def __init__(self, config, path):
        
        self.device = torch.device("cpu")
        self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
        self.preprocess = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", return_tensor='pt')
        self.tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
        self.config = config
        self.model = CaptionModel(self.config)
        trained_state_dict = torch.load(path, map_location= torch.device('cpu'),  weights_only=True)
        updated_state_dict = {k.replace("module.", ""): v for k, v in trained_state_dict.items()}
        self.model.load_state_dict(updated_state_dict)
        self.model = self.model.eval()
        self.model = self.model.to(self.device)

    
    def predict(self, image):
        
        image = io.imread(image)
        pil_img = PIL.Image.fromarray(image)
        processed_img = self.preprocess(images = pil_img, return_tensors='pt', padding=True)
        image_tensor = processed_img['pixel_values']
        image = image_tensor.to(self.device)
        
        with torch.no_grad():
            prefix = self.clip_model.get_image_features(pixel_values = image).to(self.device, dtype=torch.float32)
            prefix_embed = self.model.clip_embedding_mapping(prefix).reshape(1, self.config.prefix_length, -1)
        
        
        return generate_beam(self.model, self.tokenizer, embed=prefix_embed)


In [15]:
import numpy as np

In [16]:

    def generate_beam(model,tokenizer,embed, beam_size = 5, temperature = 1.0):

        model.eval()
        stop_token = '.'
        stop_token_index = tokenizer.encode(stop_token)[0]
        filter_value = -float("Inf")
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        entry_length = 75
        tokens = None
        generated_list = []
        scores = None           # 1d tensor containing score of every prediction
        seq_lengths = torch.ones(beam_size, device=device)             # contains seq len of each prediction
        is_stopped = torch.zeros(beam_size, device=device, dtype=torch.bool)      # flags if stopped
        

        with torch.no_grad():
            
            generated = embed         # (1,prefix_length=20,d_model=768)
            
            for i in range(entry_length):

                outputs = model.gpt(inputs_embeds=generated)
                logits = outputs.logits
                logits = logits[:, -1, :] 
                probs = F.softmax(logits, dim=-1)
                out = probs.log()             # (1,vocab_size)-> initially , later (beam_size, vocab)

                if scores is None:
                    scores, next_tokens = out.topk(beam_size, -1)      # scores (1,beam_size): out , next_tokens(1,beam_size): indices
                    generated = generated.expand(beam_size, generated.shape[1], generated.shape[2])
                    next_tokens = next_tokens.permute(1, 0)        # making B predictions with top-b tokens
                    scores = scores.squeeze(0)                     # removing batch-dim
                    if tokens is None:
                        tokens = next_tokens
                    else:
                        print("wtf")
                else:
                    logits[is_stopped] = -float(np.inf)             # marks all values as -inf for every beam which is stopped
                    logits[is_stopped, 0] = 0                       # initial element of every stopped search as 0
                    # scores[:,None] is same as scores.reshape(beam_size,1) but it doesn't create extra memory
                    scores_sum = scores[:, None] + logits           # scores is reshaped from 1d tensor to 2d for broadcasted addition
                    seq_lengths[~is_stopped] += 1
                    scores_sum_average = scores_sum / seq_lengths[:, None]           # avg: better judging parameter
                    scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(beam_size, -1)      # re-evalutaing and finding top k (beam-size) logits
                    
                    next_tokens_source = next_tokens // scores_sum.shape[1]                                # indices are flattened, mapping to corresponding beam
                    seq_lengths = seq_lengths[next_tokens_source]                                          # flattened index, mapping back to original index to tokens
                    next_tokens = next_tokens % scores_sum.shape[1]
                    next_tokens = next_tokens.unsqueeze(1)
                    tokens = tokens[next_tokens_source]
                    tokens = torch.cat((tokens, next_tokens), dim=1)
                    generated = generated[next_tokens_source]
                    scores = scores_sum_average * seq_lengths
                    is_stopped = is_stopped[next_tokens_source]
                
                
                next_token_embed = model.gpt.transformer.wte(next_tokens.squeeze()).view(generated.shape[0], 1, -1)
                generated = torch.cat((generated, next_token_embed), dim=1)
                is_stopped = is_stopped + next_tokens.eq(stop_token_index).squeeze()
                
                if is_stopped.all():
                    break
                    
        scores = scores / seq_lengths
        output_list = tokens.cpu().numpy()
        output_texts = [
            tokenizer.decode(output[: int(length)])
            for output, length in zip(output_list, seq_lengths)
        ]
        
        order = scores.argsort(descending=True)
        output_texts = [output_texts[i] for i in order]
        
        return output_texts
                

In [7]:

    def generate(model,tokenizer,embed):

        model.eval()
        stop_token = '.'
        stop_token_index = tokenizer.encode(stop_token)[0]
        filter_value = -float("Inf")
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        entry_length = 75
        tokens = None
        generated_list = []
        

        with torch.no_grad():
            
            generated = embed
            
            for i in range(entry_length):

                outputs = model.gpt(inputs_embeds=generated)
                logits = outputs.logits
                logits = logits[:, -1, :]
                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs,1)
                next_token_embed = model.gpt.transformer.wte(next_token)
                if tokens is None:
                    tokens = next_token
                else:
                    tokens = torch.cat((tokens, next_token), dim=1)
                generated = torch.cat((generated, next_token_embed), dim=1)
                if stop_token_index == next_token.item():
                    break

            output_list = list(tokens.squeeze().cpu().numpy())
            output_text = tokenizer.decode(output_list)
            generated_list.append(output_text)

        
        return generated_list[0]

In [17]:
from dataclasses import dataclass

@dataclass
class Config():
    epochs: int = 4
    batch_size: int = 128
    lr: float = 3e-5
    warmup_steps: int = 3000
    n_layers: int = 6
    n_clip_emb: int = 512
    n_heads: int = 8
    d_model: int = 768
    d_ff: int = 3072
    dropout: float = 0.1
    prefix_length: int = 20
    clip_length: int = 10
    dropout: float = 0.1

In [18]:
path = '/kaggle/input/caption_model_epoch4/pytorch/default/1/model_epoch_4.pt'
config = Config()

predictor = Predictor(config,path)

config.json:   0%|          | 0.00/4.19k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/605M [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/316 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/592 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/862k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/525k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.22M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/389 [00:00<?, ?B/s]



tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

In [19]:
image_path = '/kaggle/input/test-image/test4.jpg'

res = predictor.predict(image_path)

In [20]:
res

['A young person sitting in front with a laptop.',
 'A young person sitting in front of  a desk with a computer on top.',
 'A young person sitting in front of  a desk with computer.',
 'A young person sitting in front of  a desk with a computer.',
 'A young person sitting in front of  a desk with a computer on top of it.']