In [1]:
import pandas as pd
from transformers import DistilBertModel
from transformers import DistilBertTokenizer
from torchvision import transforms

import torch
from torch import nn
from torchvision import models
from typing import Dict, List
import cv2
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
model = models.resnet50(pretrained=True)

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /home/djankows/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:29<00:00, 3.46MB/s]


In [6]:
def slice_model(original_model, from_layer=None, to_layer=None):
    return nn.Sequential(*list(original_model.children())[from_layer:to_layer])

In [7]:
slice_model(model, to_layer=-1)

Sequential(
  (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (4): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)


In [16]:
CAPTIONS_PATH = "../../data/captions.csv"
IMAGES_PATH = "../../data/Images/"

In [7]:
df = pd.read_csv(CAPTIONS_PATH)

In [12]:
class TextEncoder(nn.Module):
    def __init__(self):
        super(TextEncoder, self).__init__()
        self.bert = DistilBertModel.from_pretrained('distilbert-base-uncased')

    def forward(self, input_ids, attention_mask) -> torch.Tensor:
        output = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        hidden_state = output[0]
        embedding = hidden_state[:, 0]
        return embedding


In [14]:
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

Downloading: 100%|██████████| 232k/232k [00:00<00:00, 577kB/s] 
Downloading: 100%|██████████| 28.0/28.0 [00:00<00:00, 15.9kB/s]
Downloading: 100%|██████████| 483/483 [00:00<00:00, 85.0kB/s]


In [29]:
text = "Not a funny project"
tokenizer.encode_plus(
            text,
            None,
            add_special_tokens=True,
            max_length = 128,
            padding='max_length',
            return_token_type_ids=True,
            truncation=True
        )


{'input_ids': [101, 2025, 1037, 6057, 2622, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

In [57]:
class EmbeddingDataset(Dataset):
    def __init__(self, captions: pd.DataFrame, tokenizer: DistilBertTokenizer, transform: transforms.Compose) -> None:
        """Class for processing input texts and images. Transform both into proper form and torch tensors.

        Args:
            captions (pd.DataFrame): Contains two columns: image filename and description of it
            tokenizer (DistilBertTokenizer): Tokenizer
        """
        self.images = captions['image'].tolist()
        self.captions = captions['caption'].tolist()
        self.tokenizer = tokenizer
        self.transform = transform

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        caption: str = self.captions[idx]
        tokens_caption: Dict[str, List[int]] = tokenizer.encode_plus(
            caption,
            None,
            add_special_tokens=True,
            max_length = 128,
            padding='max_length',
            return_token_type_ids=True,
            truncation=True
        )

        image = cv2.imread(IMAGES_PATH+self.images[idx])
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.transform:
            image = self.transform(image)

        ids = tokens_caption['input_ids']
        mask = tokens_caption['attention_mask']
        token_type_ids = tokens_caption['token_type_ids']

        return {
            'ids': torch.tensor(ids, dtype=torch.long),
            'mask': torch.tensor(mask, dtype=torch.long),
            'token_type_ids': torch.tensor(token_type_ids, dtype=torch.long),
            'image': image
        }        

    def __len__(self):
        return len(self.captions)


In [66]:
transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize((224,224)),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

In [67]:
dataset = EmbeddingDataset(df, tokenizer=tokenizer, transform=transform)

In [68]:
loader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=0)

In [69]:
for data in loader:
    print(loader)

<torch.utils.data.dataloader.DataLoader object at 0x7fd18e141d50>
<torch.utils.data.dataloader.DataLoader object at 0x7fd18e141d50>
<torch.utils.data.dataloader.DataLoader object at 0x7fd18e141d50>
<torch.utils.data.dataloader.DataLoader object at 0x7fd18e141d50>
<torch.utils.data.dataloader.DataLoader object at 0x7fd18e141d50>
<torch.utils.data.dataloader.DataLoader object at 0x7fd18e141d50>
<torch.utils.data.dataloader.DataLoader object at 0x7fd18e141d50>
<torch.utils.data.dataloader.DataLoader object at 0x7fd18e141d50>
<torch.utils.data.dataloader.DataLoader object at 0x7fd18e141d50>
<torch.utils.data.dataloader.DataLoader object at 0x7fd18e141d50>
<torch.utils.data.dataloader.DataLoader object at 0x7fd18e141d50>
<torch.utils.data.dataloader.DataLoader object at 0x7fd18e141d50>
<torch.utils.data.dataloader.DataLoader object at 0x7fd18e141d50>
<torch.utils.data.dataloader.DataLoader object at 0x7fd18e141d50>
<torch.utils.data.dataloader.DataLoader object at 0x7fd18e141d50>
<torch.uti

KeyboardInterrupt: 