In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torchvision import datasets, transforms as T
from PIL import Image

In [2]:
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from collections import OrderedDict

In [3]:
# TODO: Try a pretrained CLIP model.

In [4]:
from torchvision.datasets.utils import download_and_extract_archive

In [5]:
# download_and_extract_archive("http://images.cocodataset.org/zips/train2017.zip",
#                              download_root="../datasets/COCO",
#                              remove_finished=True)

In [6]:
# download_and_extract_archive("http://images.cocodataset.org/zips/val2017.zip",
#                              download_root="../datasets/COCO",
#                              remove_finished=True)

In [7]:
# download_and_extract_archive("http://images.cocodataset.org/annotations/annotations_trainval2017.zip",
#                              download_root="../datasets/COCO",
#                              remove_finished=True)

In [8]:
# Use the first CUDA GPU when one is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [9]:
# Side length, in pixels, of the square crops produced by the preprocessing pipelines.
input_size = 224

In [10]:
# Per-split preprocessing pipelines. Both finish by converting to a tensor and
# normalising with the channel mean/std commonly used for ImageNet-trained
# torchvision models.
preproc = {
    # Training: random scale/aspect-ratio crop plus a random horizontal mirror.
    'train': T.Compose([
        T.RandomResizedCrop(input_size),
        # RandomHorizontalFlip's only argument is the flip probability `p`, not
        # an image size. Passing `input_size` (224) made every image flip, so
        # the augmentation was effectively disabled; use the default p=0.5.
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
    # Validation: deterministic resize of the shorter side, then a centre crop.
    'val': T.Compose([
        T.Resize(input_size),
        T.CenterCrop(input_size),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

In [30]:
# Standalone ResNeXt-50 (32x4d) built without pretrained weights; used below to probe layer shapes.
backbone = models.resnext50_32x4d(pretrained=False)

In [12]:
# Number of input features of the stock classification layer (2048 in the recorded output).
backbone.fc.in_features

2048

In [15]:
class ProjectionHead(nn.Module):
    """Project features of arbitrary width to ``d_model`` with a residual MLP block.

    The block computes ``LayerNorm(projected + Dropout(Linear(Dropout(GELU(projected)))))``
    where ``projected = Linear(inp)``. The residual is taken from the *projected*
    tensor rather than from the raw input, so the head works for any input width,
    not only when the input already has ``d_model`` features.

    Args:
        d_model: Output feature size (size of the last dimension of the result).
        dp_rate: Dropout probability used after the activation and after the
            second linear layer.
    """

    def __init__(self, d_model=768, dp_rate=0.1):
        super().__init__()
        # Input width is inferred from the first batch that passes through.
        self.proj = nn.LazyLinear(d_model)
        self.activation = nn.GELU()
        self.dropout1 = nn.Dropout(dp_rate)
        # This layer always maps d_model -> d_model, so no lazy inference is needed.
        self.lin = nn.Linear(d_model, d_model)
        self.dropout2 = nn.Dropout(dp_rate)
        self.ln = nn.LayerNorm(d_model)

    def forward(self, inp):
        """Return the projected, layer-normalised features.

        Args:
            inp: Tensor of shape ``(..., in_features)``.

        Returns:
            Tensor of shape ``(..., d_model)``.
        """
        projected = self.proj(inp)
        x = self.dropout1(self.activation(projected))
        x = self.dropout2(self.lin(x))
        # Residual from the projected tensor: `inp` may have a different width
        # than d_model, in which case `x + inp` would fail to broadcast.
        return self.ln(x + projected)

In [16]:
# Random stand-in for a batch of one 3-channel 224x224 image; used only to inspect shapes.
inp = torch.randn(1,3,224,224)

In [39]:
# Bypass global average pooling so the [1, 2048, 7, 7] feature map passes through unchanged.
backbone.avgpool = nn.Identity()
# The backbone then applies flatten(x, 1), so `fc` receives [1, 100352] (see the hook output
# recorded below); make `fc` a pass-through as well.
backbone.fc = nn.Identity()

In [40]:
def ap_hook(module, inp, out):
    """Forward hook: print the shape of the first positional input, then of the output."""
    for tensor in (inp[0], out):
        print(tensor.shape)
# Attach the shape-printing hook to both layers; keep the handles so the hooks can be removed later.
handle = backbone.avgpool.register_forward_hook(ap_hook)
handle2 = backbone.fc.register_forward_hook(ap_hook)

In [41]:
# Run the dummy input through the backbone; each hooked layer prints its input and output shape.
out = backbone(inp)

torch.Size([1, 2048, 7, 7])
torch.Size([1, 2048, 7, 7])
torch.Size([1, 100352])
torch.Size([1, 100352])


In [35]:
# NOTE(review): the recorded [1, 1000] output has an earlier execution count than the cell that
# replaces `fc`, so it presumably predates that change — re-run to confirm the current shape.
out.shape

torch.Size([1, 1000])

In [21]:
# Restore the spatial layout of the flattened features: batch of 1, 2048 channels, 7x7 grid.
# NOTE(review): only valid when `out` holds 1*2048*7*7 elements, i.e. after avgpool/fc were bypassed.
out = out.view([1, 2048, 7, 7])

In [24]:
# Collapse the two spatial dimensions into a single axis of 49 positions.
torch.flatten(out, 2).shape

torch.Size([1, 2048, 49])

In [36]:
# Detach the debugging hooks so later forward passes stop printing shapes.
handle.remove()
handle2.remove()

In [11]:
class CaptionModel(nn.Module):
    """Image-captioning model: frozen ResNeXt-50 features feeding a Transformer decoder.

    The backbone's pooling and classification layers are bypassed so that every
    spatial location of the final feature map becomes one "memory" token. Tokens
    are projected to ``d_model`` by ``projection_head``, given a learned
    positional embedding, and attended to by a standard ``nn.TransformerDecoder``.

    Args:
        projection_head: Module mapping ``(..., backbone_channels)`` to
            ``(..., d_model)``. It is kept outside the frozen backbone and stays
            trainable.
        tgt_vocab_size: Size of the output vocabulary (width of the logits).
        num_decoder_layers: Number of stacked decoder layers.
        nhead: Number of attention heads per decoder layer.
        d_model: Feature size used by the decoder; must match the output width of
            ``projection_head``.
        dim_feedforward: Hidden size of each decoder layer's feed-forward block.
        dropout: Dropout probability inside the decoder layers.
        activation: Activation used in the decoder feed-forward blocks.
        num_positions: Maximum number of spatial locations the positional
            embedding can cover. The default of 49 corresponds to the 7x7 feature
            map this backbone produces for 224x224 inputs.
    """

    def __init__(self, projection_head, tgt_vocab_size, num_decoder_layers=6, nhead=8, d_model=768,
                 dim_feedforward=2048, dropout=0.1, activation='relu', num_positions=49):
        super().__init__()
        self.backbone = models.resnext50_32x4d(pretrained=False)
        self.freeze_weights(self.backbone)
        # Channel count of the final feature map; read before `fc` is replaced.
        self.proj_in_dim = self.backbone.fc.in_features
        # Bypass pooling and classification so the backbone returns the flattened
        # (B, C*H*W) feature map; it is reshaped into tokens in `forward`.
        self.backbone.avgpool = nn.Identity()
        self.backbone.fc = nn.Identity()
        # Applied per spatial location in `forward`. Installing it as `backbone.fc`
        # would feed it the flattened C*H*W vector instead of per-location features.
        self.projection_head = projection_head
        # Learned positional embedding, broadcast over the batch dimension.
        self.pos_emb = nn.Parameter(torch.zeros(num_positions, 1, d_model))

        decoder_layer = nn.TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout,
                                                   activation)
        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm)
        self.generator = nn.Linear(d_model, tgt_vocab_size)

    def freeze_weights(self, model):
        """Disable gradient computation for every parameter of ``model``."""
        for param in model.parameters():
            param.requires_grad = False

    def forward(self, x, tgt, tgt_mask=None, tgt_key_padding_mask=None):
        """Compute vocabulary logits for each target position.

        Args:
            x: Image batch of shape ``(B, 3, H, W)``.
            tgt: Target sequence features of shape ``(T, B, d_model)``. The decoder
                is sequence-first and this model has no token-embedding layer, so
                ``tgt`` is assumed to be embedded already — TODO confirm with callers.
            tgt_mask: Optional ``(T, T)`` attention mask for the target sequence.
                When omitted, a causal mask is used so a position cannot attend to
                later positions.
            tgt_key_padding_mask: Optional ``(B, T)`` mask marking padded target
                positions.

        Returns:
            Tensor of shape ``(T, B, tgt_vocab_size)`` with unnormalised logits.

        Raises:
            ValueError: If the feature map has more spatial locations than
                ``num_positions``.
        """
        feats = self.backbone(x)                                   # (B, C*H*W)
        feats = feats.reshape(feats.size(0), self.proj_in_dim, -1)  # (B, C, H*W)
        memory = self.projection_head(feats.permute(2, 0, 1))      # (H*W, B, d_model)

        num_tokens = memory.size(0)
        if num_tokens > self.pos_emb.size(0):
            raise ValueError(
                f"feature map has {num_tokens} locations but the positional embedding "
                f"only covers {self.pos_emb.size(0)}; increase num_positions"
            )
        memory = memory + self.pos_emb[:num_tokens]

        if tgt_mask is None:
            # Additive causal mask: -inf strictly above the diagonal.
            seq_len = tgt.size(0)
            tgt_mask = torch.triu(
                torch.full((seq_len, seq_len), float('-inf'), dtype=tgt.dtype, device=tgt.device),
                diagonal=1,
            )

        out = self.decoder(tgt, memory=memory, tgt_mask=tgt_mask,
                           tgt_key_padding_mask=tgt_key_padding_mask)
        return self.generator(out)