<a href="https://colab.research.google.com/github/FeryET/DeepLearning_CA7/blob/master/DL_CA7_Q1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Colab-only setup cell: mounts Google Drive, points the Kaggle CLI at a
# config directory stored on Drive, then downloads and extracts Flickr8k.
import os
from google.colab import drive
drive.mount('/content/drive')
# Directory the Kaggle CLI reads its configuration from
# (presumably where kaggle.json lives on Drive -- verify).
os.environ['KAGGLE_CONFIG_DIR'] = "/content/drive/MyDrive/kaggle"
# Download the Flickr8k archive through the Kaggle CLI.
!kaggle datasets download -d adityajn105/flickr8k
# -q: quiet, -o: overwrite existing files without prompting.
!unzip -qo "/content/flickr8k.zip"

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
flickr8k.zip: Skipping, found more recently modified local copy (use --force to force download)


In [None]:
import numpy as np
from PIL import Image
import re
import string
from glob import glob
import pandas as pd
import string 
import itertools
import collections

import torch
import torchvision
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

# Defining Constants

In [None]:
# Special tokens used by the caption pipeline.
# START_TOKEN / END_TOKEN are prepended / appended to every caption by
# preprocess_text; PAD_TOKEN is added to the dataset vocabulary and its index
# is what PadCaptions reads as the padding value.
START_TOKEN = "<SOS>"
END_TOKEN = "<EOS>"
PAD_TOKEN = "<PAD>"
# UNK_TOKEN = "<UNK>"

# Maximum number of tokens kept per caption; the dataset truncates longer ones.
MAX_LENGTH = 150


# Preprocessing Pipelines

### Captions

In [None]:


def preprocess_text(text):
  """Normalize a raw caption and split it into a list of tokens.

  The caption is lower-cased, every ASCII punctuation character is removed,
  and the remaining text is split on whitespace. The result is wrapped in
  START_TOKEN / END_TOKEN.

  Args:
    text: raw caption string.

  Returns:
    List of string tokens: [START_TOKEN, word_1, ..., word_n, END_TOKEN].
  """
  prep = text.lower()
  # str.translate needs a translation table; passing string.punctuation
  # directly does not delete anything. Punctuation is stripped BEFORE the
  # special tokens are added so their "<" / ">" characters survive.
  prep = prep.translate(str.maketrans("", "", string.punctuation))
  # split() with no separator collapses whitespace runs (including the gaps
  # left behind by removed punctuation) and never yields empty tokens.
  return [START_TOKEN, *prep.split(), END_TOKEN]


class VocabTransform:
  """Callable that maps a sequence of tokens to their vocabulary indices.

  Args:
    vocab: mapping from token string to integer index.
    unk_token: optional token whose index replaces any token missing from
      `vocab`. When None (the default) a missing token raises KeyError,
      which is the original behavior.

  Raises:
    ValueError: if `unk_token` is given but is not itself in `vocab`.
  """

  def __init__(self, vocab, unk_token=None):
    if unk_token is not None and unk_token not in vocab:
      raise ValueError(f"unk_token {unk_token!r} is not in the vocabulary")
    self.vocab = vocab
    self.unk_token = unk_token

  def __call__(self, tokenized):
    """Return the list of indices for `tokenized` (an iterable of tokens)."""
    if self.unk_token is None:
      # Strict mode: out-of-vocabulary tokens surface as KeyError.
      return [self.vocab[t] for t in tokenized]
    unk_idx = self.vocab[self.unk_token]
    return [self.vocab.get(t, unk_idx) for t in tokenized]

### Images

In [None]:
# This is copied from https://pytorch.org/hub/pytorch_vision_resnet/
# Image preprocessing pipeline, applied in order:
#   1. Resize(256)      - resize the image,
#   2. CenterCrop(224)  - take the central 224x224 crop,
#   3. ToTensor()       - convert the image to a tensor,
#   4. Normalize(...)   - per-channel normalization with the mean/std values
#                         given on the page referenced above.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Defining the Dataset

In [None]:
class FilckrDataset(Dataset):
  """Flickr8k dataset yielding one image together with all of its captions.

  The captions CSV must have an `image` column (file name) and a `caption`
  column (raw text). Captions are tokenized with `preprocess_text`, the rows
  are sorted by image name, and a vocabulary is built from tokens that occur
  more than `vocab_thresh` times (plus PAD_TOKEN). Tokens at or below the
  threshold are NOT in `self.vocab`.

  Assumes every image has exactly `captions_per_image` rows in the CSV, so
  item `i` corresponds to rows `[i * step, (i + 1) * step)` of the sorted
  frame -- verify for datasets other than Flickr8k.

  Args:
    images_path: directory containing the image files.
    csv_path: path to the captions CSV.
    image_transforms: callable applied to the RGB PIL image.
    caption_transforms: callable applied to the list of tokenized captions
      of one image (a list of token lists, each truncated to MAX_LENGTH).
    vocab_thresh: a token must occur more than this many times to be kept.
    captions_per_image: number of caption rows per image (default 5).
  """

  def __init__(self,
               images_path,
               csv_path,
               image_transforms,
               caption_transforms,
               vocab_thresh=5,
               captions_per_image=5):
    self.image_transforms = image_transforms
    self.caption_transforms = caption_transforms
    self.images_path = images_path
    self.images_fnames = sorted(os.listdir(images_path))
    self.csv_path = csv_path
    # Number of consecutive caption rows that belong to each image.
    self.step = captions_per_image
    self.df = pd.read_csv(csv_path)
    self._preprocess_captions(vocab_thresh)

  @staticmethod
  def _build_vocab(tokenized_captions, vocab_thresh, pad_token):
    """Build a token -> index dict from an iterable of token lists."""
    # chain.from_iterable flattens the token lists; plain chain() would hand
    # the (unhashable) lists themselves to Counter.
    counter = collections.Counter(
        itertools.chain.from_iterable(tokenized_captions))
    kept = {tok for tok, freq in counter.items() if freq > vocab_thresh}
    kept.add(pad_token)
    # Sorted so that indices are deterministic across runs.
    return {tok: idx for idx, tok in enumerate(sorted(kept))}

  def _preprocess_captions(self, vocab_thresh):
    """Tokenize captions, sort rows by image and build `self.vocab`."""
    self.df["cleaned"] = self.df["caption"].apply(preprocess_text)
    # A stable sort keeps the original caption order within each image.
    self.df = self.df.sort_values(by="image", kind="stable")
    self.vocab = self._build_vocab(self.df["cleaned"], vocab_thresh, PAD_TOKEN)

  @property
  def vocab_size(self):
    """Number of entries in the vocabulary (including PAD_TOKEN)."""
    return len(self.vocab)

  def __len__(self):
    """Number of images, i.e. caption rows divided by captions per image."""
    return len(self.df) // self.step

  def __getitem__(self, idx):
    """Return {"image": transformed image, "captions": transformed captions}.

    Raises:
      IndexError: if `idx` is outside `[-len(self), len(self))`.
    """
    n_items = len(self)
    if idx < 0:
      idx += n_items
    if not 0 <= idx < n_items:
      raise IndexError(f"index {idx} out of range for {n_items} items")

    start = idx * self.step  # first caption row of this image
    rows = self.df.iloc[start:start + self.step]
    # The file name comes from the caption rows themselves, so the image and
    # its captions cannot drift apart if the directory listing differs.
    fname = rows["image"].iloc[0]
    fpath = os.path.join(self.images_path, fname)

    # convert() loads the pixel data, so the file can be closed right away;
    # forcing RGB keeps the channel count fixed for the image transforms.
    with Image.open(fpath) as img:
      image = img.convert("RGB")

    # Truncating
    captions = [c[:MAX_LENGTH] for c in rows["cleaned"]]

    image = self.image_transforms(image)
    captions = self.caption_transforms(captions)

    return {
        "image": image, "captions": captions
    }

In [None]:
# This is needed for batches

class RepeatImages:
  """Batch transform that repeats each image `num_repeat` times along dim 0.

  The `data` dict is updated in place (its "image" entry is replaced) and the
  same dict object is returned.
  """

  def __init__(self, num_repeat):
    self.num_repeat = num_repeat

  def __call__(self, data):
    # Each entry along dim 0 is duplicated consecutively: a, a, ..., b, b, ...
    data["image"] = data["image"].repeat_interleave(self.num_repeat, dim=0)
    return data

class PadCaptions:
  """Batch transform that pads variable-length caption sequences.

  Args:
    vocab: token -> index mapping; the index of PAD_TOKEN is used as the
      padding value.
  """

  def __init__(self, vocab):
    self.pad_idx = vocab[PAD_TOKEN]

  def __call__(self, data):
    """Replace data["captions"] with one padded (batch, max_len) tensor.

    `data["captions"]` is a sequence of 1-D index sequences (tensors or
    lists of ints). The dict is updated in place and returned.
    """
    # pad_sequence only accepts tensors; as_tensor leaves tensors untouched
    # and converts plain lists of indices.
    sequences = [torch.as_tensor(c) for c in data["captions"]]
    # The padded result must be stored back -- pad_sequence is not in-place.
    data["captions"] = torch.nn.utils.rnn.pad_sequence(
        sequences,
        batch_first=True,
        padding_value=self.pad_idx)
    return data

In [None]:

# Influenced by https://github.com/siddsrivastava/Image-captioning/blob/master/model.py

class FlickrEncoderCNN(nn.Module):
  """Image encoder: pretrained ResNet-18 backbone plus a linear projection.

  The ResNet's final fully connected layer is dropped and replaced by a new
  `nn.Linear` that maps the backbone features to `embedding_dim`.

  Args:
    embedding_dim: size of the output feature vector per image.
    do_freeze: when truthy, the backbone parameters are excluded from
      gradient updates; the projection layer is always trainable.
  """

  def __init__(self,
               embedding_dim,
               do_freeze):
    super().__init__()
    # Loading resnet
    # NOTE: newer torchvision releases prefer the `weights=` argument;
    # `pretrained=True` is kept as in the original notebook.
    resnet = torchvision.models.resnet18(pretrained=True)
    fc_in_features = resnet.fc.in_features

    # Defining layers: every child except the last (the classification fc).
    modules = list(resnet.children())[:-1]
    self.resnet = nn.Sequential(*modules)
    self.fc = nn.Linear(fc_in_features, embedding_dim)

    # Freezing if needed: frozen means requires_grad is False.
    if do_freeze:
      for param in self.resnet.parameters():
        param.requires_grad = False
    for param in self.fc.parameters():
      param.requires_grad = True

  def forward(self, x):
    """Encode a batch of images into (batch, embedding_dim) features."""
    x = self.resnet(x)
    # Collapse everything after the batch dimension before the projection.
    x = torch.flatten(x, 1)
    x = self.fc(x)
    return x


class FlickrDecoderRNN(nn.Module):
  """Caption decoder: token embedding -> LSTM -> linear layer over the vocab.

  Args:
    vocab_size: number of tokens in the vocabulary.
    embedding_dim: embedding size; must equal the image feature size because
      the image features are fed to the LSTM as the first time step.
    hidden_size: LSTM hidden size (per direction).
    padding_idx: vocabulary index of the padding token.
    bidirectional: whether the LSTM is bidirectional.
    dropout: dropout value forwarded to `nn.LSTM`.
  """

  def __init__(self,
               vocab_size,
               embedding_dim,
               hidden_size,
               padding_idx,
               bidirectional=False,
               dropout=0):
    super().__init__()
    # Creating embeddings
    # NOTE: sparse=True yields sparse gradients, which only some optimizers
    # support (see the nn.Embedding documentation).
    self.embed = nn.Embedding(num_embeddings=vocab_size,
                              embedding_dim=embedding_dim,
                              padding_idx=padding_idx,
                              scale_grad_by_freq=True,
                              sparse=True)
    self.lstm = nn.LSTM(input_size=embedding_dim,
                        hidden_size=hidden_size,
                        batch_first=True,
                        dropout=dropout,
                        bidirectional=bidirectional)
    # A bidirectional LSTM concatenates both directions, doubling its output.
    lstm_out_features = hidden_size * (2 if bidirectional else 1)
    self.fc = nn.Linear(lstm_out_features, vocab_size)

  def forward(self, features, caption_seqs):
    """Return per-step vocabulary scores.

    Args:
      features: (batch, embedding_dim) image features.
      caption_seqs: (batch, seq_len) integer token indices.

    Returns:
      Tensor of shape (batch, seq_len, vocab_size). The last LSTM state is
      also stored on `self.hidden`.
    """
    # Drop the last token: the image features take the first time step, so
    # the input length stays equal to seq_len.
    caption_seqs = caption_seqs[:, :-1]
    embeddings = self.embed(caption_seqs)
    total_input = torch.cat((features.unsqueeze(1), embeddings), 1)
    lstm_out, self.hidden = self.lstm(total_input)
    outputs = self.fc(lstm_out)
    return outputs


class FlickrRNN(nn.Module):
  """Full captioning model: FlickrEncoderCNN followed by FlickrDecoderRNN.

  Args:
    vocab_size: number of tokens in the vocabulary.
    embedding_dim: shared size of image features and token embeddings.
    hidden_size: decoder LSTM hidden size.
    padding_idx: vocabulary index of the padding token.
    bidirectional: forwarded to the decoder.
    dropout: forwarded to the decoder.
    freeze: forwarded to the encoder as its `do_freeze` flag.
  """

  def __init__(self,
               vocab_size,
               embedding_dim,
               hidden_size,
               padding_idx,
               bidirectional=False,
               dropout=0,
               freeze=True):
    super().__init__()
    self.encoder = FlickrEncoderCNN(embedding_dim, freeze)
    self.decoder = FlickrDecoderRNN(vocab_size,
                                    embedding_dim,
                                    hidden_size,
                                    padding_idx,
                                    bidirectional,
                                    dropout)

  def forward(self, images, caption_seqs):
    """Return log-probabilities over the vocabulary for every time step."""
    features = self.encoder(images)
    outputs = self.decoder(features, caption_seqs)
    # nn.LogSoftmax(outputs) would only construct a module; the functional
    # form actually normalizes the scores over the last (vocabulary) axis.
    outputs = nn.functional.log_softmax(outputs, dim=-1)
    return outputs

  def predict(self, images, states, max_len = 30):
    """Caption generation; left without a body in the original notebook.

    Raises:
      NotImplementedError: always, until a decoding strategy is written.
    """
    raise NotImplementedError("FlickrRNN.predict is not implemented yet")
    

    




In [None]:
# Scratch check: Counter keys come back in first-seen order.
counter = collections.Counter("1123411233000009999")
counter.keys()

dict_keys(['1', '2', '3', '4', '0', '9'])