In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pickle

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [3]:
file_loc = '/kaggle/input/clip-emb-coco-2017/final_embeddings.pkl'

with open(file_loc, 'rb') as f:
    data = pickle.load(f)

clip_embeddings = data["clip_embedding"]  # The concatenated image embeddings
captions = data["captions"]  # The list of captions

print(captions[:5])
print(len(clip_embeddings))

['an old man checking his cell phone by a building ', 'A simple bathroom with black and orange towels.', 'A kitchen with a wooden center island under two lights.', 'A large, spacious kitchen with an island in the middle.', 'A center island sitting in the middle of a kitchen.']
391753


In [13]:
from dataclasses import dataclass

@dataclass
class Config():
    epochs: int = 2
    batch_size: int = 32
    lr: float = 1e-4
    warmup_steps: int = 3000
    n_layers: int = 6
    n_clip_emb: int = 512
    n_heads: int = 8
    d_model: int = 768
    d_ff: int = 3072
    dropout: float = 0.1
    prefix_length: int = 20
    clip_length: int = 10
    dropout: float = 0.1

In [5]:
import torch
from torch.utils.data import Dataset
from transformers import GPT2Tokenizer
import pickle

class Caption_Dataset(Dataset):

    def __init__(self, file_path, prefix_length, extract_from_file = False):
        self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')      # gpt2 tokenizer from HF
        self.prefix_length = prefix_length
        
        # clip embedding & captions file 
        with open(file_path, 'rb') as f:
            data = pickle.load(f)
        
        self.prefixes = data["clip_embedding"]
        self.captions = data["captions"]

        self.captions_tokens = []        # to store list of tokenised captions
        max_seq_len = 0

        if extract_from_file:          # if tokenisation is already done, load from file
            with open('/kaggle/working/caption_tokens.pkl', 'rb') as f:
                self.captions_tokens, self.max_seq_len = pickle.load(f)
        else:
            for caption in self.captions:
                self.captions_tokens.append(torch.tensor(self.tokenizer.encode(caption), dtype=torch.int64))           # storing tokenised captions
                max_seq_len = max(max_seq_len, self.captions_tokens[-1].shape[0])    

            with open("/kaggle/working/caption_tokens.pkl", 'wb') as f:
                pickle.dump([self.captions_tokens, max_seq_len], f)
            
            self.max_seq_len = max_seq_len

    def __len__(self):
        return len(self.captions)
    
    def pad_mask(self,idx):

        tokens = self.captions_tokens[idx]
        padding = self.max_seq_len - tokens.shape[0]
        if padding > 0:
            tokens = torch.cat((tokens, torch.zeros(padding, dtype=torch.int64) - 1))          # padding with -1 tokens
            self.captions_tokens[idx] = tokens
        elif padding < 0:
            tokens = tokens[:self.max_seq_len]          # trimming the tokens
            self.captions_tokens[idx] = tokens

        mask = tokens.ge(0)              # assigns false to -1 tokens  
        tokens[~mask] = 0                # assigns 0 to false (-1) 
        mask = mask.float()
        # prefix should always be considered in attention mechanism
        mask = torch.cat((torch.ones(self.prefix_length), mask), dim=0)     # adding prefix length to mask
        
        return tokens, mask

    def __getitem__(self, idx):
        tokens, mask = self.pad_mask(idx)
        return tokens, mask, self.prefixes[idx]

In [6]:
class FeedForward(nn.Module):
    '''Feed Forward Network for Transformer'''

    def __init__(self, d_model, d_ff, dropout_ratio = 0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout_ratio)

    def forward(self, x):
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        x = self.dropout(x)
        return x


class MultiHeadSA(nn.Module):
    '''Multi Head Self Attention Layer'''

    def __init__(self, n_heads, d_model, input_dim):  
        super().__init__()    
        assert d_model % n_heads == 0 , "Invalid head_size for the given d_model"
        self.n_heads = n_heads
        self.d_model = d_model
        self.head_size = d_model // n_heads
        self.input_dim = input_dim
        self.qkv_proj = nn.Linear(input_dim, 3 * d_model)
        self.linear = nn.Linear(d_model, d_model)
    
    def forward(self, X, mask = None):

        B, T, C = X.shape
        assert C == self.input_dim, "Input dimension does not match the model input dimension"
        qkv = self.qkv_proj(X)                                    # (B,T,3*D)
        qkv = qkv.reshape(B, T, self.n_heads, 3 * self.d_model // self.n_heads)
        qkv = qkv.permute(0,2,1,3)
        q, k, v = torch.chunk(qkv, 3, dim=-1)
        if mask is None:
            attention_score = torch.softmax(q @ k.transpose(-2, -1) / (self.head_size ** 0.5), dim=-1)
        else:
            mask = mask.unsqueeze(1)  # for broadcasting
            attention_score = torch.softmax(q @ k.transpose(-2, -1) / (self.head_size ** 0.5) + mask, dim=-1)
        res = attention_score @ v                                       # (B,H,T,head_size)
        res = res.permute(0,2,1,3).reshape(B, T, self.d_model)   
        res = self.linear(res)

        return res               

class EncoderLayer(nn.Module):
    '''Single Layer of Transformer Encoder'''
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.multi_head_sa = MultiHeadSA(self.config.n_heads, self.config.d_model, self.config.d_model)
        self.feed_forward = FeedForward(self.config.d_model, self.config.d_ff, self.config.dropout)
        self.norm1 = nn.LayerNorm(self.config.d_model)
        self.norm2 = nn.LayerNorm(self.config.d_model)
        self.dropout = nn.Dropout(self.config.dropout)

    def forward(self, x, mask = None):
        # ordering of layernorm is like the GPT2 paper and not like the original transformer paper
        # layer norm before attention and feed forward
        res = self.norm1(x)
        res = self.multi_head_sa(res, mask)
        res = x + self.dropout(res)          # residual connection and dropout
        res = self.norm2(res)
        res2 = self.feed_forward(res)
        res = res + self.dropout(res2)

        return res

class Transformer(nn.Module):
    '''Modified Encoder Only Representation of Transformer
        for converting CLIP embedding to GPT2 input'''
    
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.encoder_layers = nn.ModuleList([EncoderLayer(config) for _ in range(self.config.n_layers)])

    def forward(self, x, mask = None):
        for layer in self.encoder_layers:
            x = layer(x, mask)
        
        return x
    

In [7]:
class Mapping_Network(nn.Module):

    def __init__(self, config):

        super().__init__()
        self.config = config
        self.linear = nn.Linear(config.n_clip_emb, config.clip_length * config.d_model)        # 512 -> clip_length * d_model(768)
        self.fixed_prefix = nn.Parameter(torch.randn(config.prefix_length, config.d_model), requires_grad=True)            # fixed prefix
        self.transformer = Transformer(config)


    def forward(self, x):
        # x: (batch_size, n_clip_emb)
        res = self.linear(x)          # (batch_size, clip_length * d_model)
        res = res.view(res.shape[0], self.config.clip_length, self.config.d_model)        # (batch_size, clip_length, d_model)
        prefix = self.fixed_prefix.unsqueeze(0)             # adding batch dimension
        prefix = prefix.repeat(res.shape[0], 1, 1)          # (batch_size, prefix_length, d_model)
        # first clip_embedding followed by fixed prefix
        res = torch.cat((res, prefix), dim=1)               # (batch_size, prefix_length + clip_length, d_model)
        res = self.transformer(res)                         
        
        return res[:,self.config.clip_length:]       


In [8]:
from transformers import GPT2LMHeadModel

class CaptionModel(nn.Module):

    def __init__(self,config):
        super().__init__()
        self.gpt = GPT2LMHeadModel.from_pretrained('gpt2')
        self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
        self.config = config
        self.clip_embedding_mapping = Mapping_Network(config)

    def forward(self, tokens, prefix, mask):
        cap_emb = self.gpt.transformer.wte(tokens)          # (batch_size, seq_len, embedding_size)
        clip_emb = self.clip_embedding_mapping(prefix).view(-1,self.config.prefix_length,self.gpt_embedding_size)      # (batch_size, prefix_length, d_model_gpt2)
        res = torch.cat((clip_emb, cap_emb), dim=1)
        res = self.gpt(inputs_embeds = res, attention_mask = mask, return_dict = True)

        return res.logits
    
    def train(self, mode = True):
        super(CaptionModel, self).train(mode)
        # freeze and train
        self.gpt.eval()                      # gpt2 weights remain fixed
        return self

In [14]:
from torch.utils.data import DataLoader
from transformers import AdamW, get_linear_schedule_with_warmup
from tqdm import tqdm 

config = Config()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
file_path = '/kaggle/input/clip-emb-coco-2017/final_embeddings.pkl'

model = CaptionModel(config)
model = nn.DataParallel(model)         # Use multiple GPUs
# Move model to GPU
model = model.to(device)

model.train()

dataset = Caption_Dataset(file_path, prefix_length = config.prefix_length, extract_from_file = True)
train_loader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True)
optimizer = AdamW(model.parameters(), lr=config.lr)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=config.warmup_steps, num_training_steps=config.epochs * len(train_loader))

In [10]:
len(train_loader)

12243

In [11]:
loss_hist = []

In [16]:
model.load_state_dict(torch.load('/kaggle/input/caption_till_epch2/pytorch/default/1/model_epoch_2.pt'))

  model.load_state_dict(torch.load('/kaggle/input/caption_till_epch2/pytorch/default/1/model_epoch_2.pt'))


<All keys matched successfully>

In [None]:
for epoch in range(2,4):
    
        print(f"Epoch: {epoch}")
        progress = tqdm(total=len(train_loader), desc="Training", leave = False)
    
        for idx, (tokens, mask, prefix) in enumerate(train_loader):

            model.zero_grad()
            tokens, mask, prefix = tokens.to(device), mask.to(device), prefix.to(device, dtype=torch.float32)
            outputs = model(tokens, prefix, mask)      # (B, prefix_length + max_seq_length, d_model)
            logits = outputs[:, config.prefix_length - 1:-1]             # (B, max_seq_length, vocab_size)
            # only consider the output for caption tokens

            # calculating loss
            loss = F.cross_entropy(logits.reshape(-1, logits.shape[-1]), tokens.flatten(), ignore_index=0)    # ignore_index to ignore loss for padding tokens
            loss.backward()
            optimizer.step()       # gradient descent
            scheduler.step()       # lr decay

            optimizer.zero_grad()
            
            # updating progress bar
            progress.set_postfix(loss=loss.item())
            progress.update()

            loss_hist.append(loss.item())

            if (idx+1)%3000==0:
                torch.save(model.state_dict(), f'model_epoch:{epoch+1}:{idx+1}.pt')
                with open(f'loss_hist:{epoch+1}:{idx+1}.pkl', 'wb') as f:
                    pickle.dump(loss_hist, f)
                loss_hist = []

        progress.close()
        
        torch.save(model.state_dict(), f"model_epoch_{epoch+1}.pt")

Epoch: 2


Training:  43%|████▎     | 5210/12243 [1:53:46<2:33:43,  1.31s/it, loss=1.76]

In [None]:
print("test")