<a href="https://colab.research.google.com/github/FeryET/DeepLearning_CA7/blob/master/DL_CA7_Q1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os
from google.colab import drive
drive.mount('/content/drive')
os.environ['KAGGLE_CONFIG_DIR'] = "/content/drive/MyDrive/kaggle"
!kaggle datasets download -d adityajn105/flickr8k
!unzip -qo "/content/flickr8k.zip"

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
flickr8k.zip: Skipping, found more recently modified local copy (use --force to force download)


In [2]:
import numpy as np
from PIL import Image
import re
import string
from glob import glob
import pandas as pd
import string 
import itertools
import collections

from tqdm.auto import tqdm


import torch
import torchvision
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms



# Defining Constants

In [3]:
# Locations
IMAGE_ROOT = "/content/Images"
CAPTION_CSV_LOC = "/content/captions.txt"


# Tokens
START_TOKEN = "<SOS>"
END_TOKEN = "<EOS>"
PAD_TOKEN = "<PAD>"
# UNK_TOKEN = "<UNK>"

# Captions related
MAX_LENGTH = 256 # LSTM units
MIN_WORD_FREQ = 15
EMBEDDING_DIM = 100
CAPTIONS_PER_IMAGE = 5


# Dataset specifice
TEST_SPLIT = 0.1

# Processing Documents

### Preprocessing Captions

In [4]:
def preprocess_text(text):
  prep = text.lower()
  prep = re.sub("\s+", " ", prep)
  prep = prep.translate(str.maketrans('', '', string.punctuation))
  return prep.split()


caption_df = pd.read_csv(CAPTION_CSV_LOC)
caption_df.sort_values(by="image", inplace=True)
caption_df["cleaned"] = caption_df["caption"].apply(preprocess_text)

### Creating Vocabulary

In [5]:
counter = collections.Counter(itertools.chain(*caption_df["cleaned"]))
words = sorted([v for v, n in counter.items() if n > MIN_WORD_FREQ])
words += [PAD_TOKEN, START_TOKEN, END_TOKEN] + words
vocab = {w: idx for idx, w in enumerate(words)}
id2word = {idx: w for idx, w in enumerate(words)}

### Defining Transforms for Captions

In [6]:
class VocabTransform:
  def __init__(self, vocab):
    self.vocab = vocab
  
  def __call__(self, tokenized):
    return [self.vocab[t] for t in tokenized if t in self.vocab.keys()]

class CaptionConditioner:
  def __init__(self, max_length=MAX_LENGTH):
    self.max_length = max_length
  
  def __call__(self, tokenized):
    return [START_TOKEN] + tokenized[:self.max_length - 2] + [END_TOKEN]

class TextIndicesToTensor:
  def __call__(self, item):
    return torch.tensor(tuple(item), dtype=torch.int64)



caption_transforms = transforms.Compose(
    [
     CaptionConditioner(),
     VocabTransform(vocab),
     TextIndicesToTensor(),
    ]
)

### Images

In [7]:
# This is copied from https://pytorch.org/hub/pytorch_vision_resnet/
image_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(224), # As an augmenting process
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Defining the Dataset

In [8]:
class FlickrDataset(Dataset):
  def __init__(self, 
               images_path, 
               caption_df, 
               image_transforms, 
               caption_transforms,
               captions_per_image=CAPTIONS_PER_IMAGE):
    self.image_transforms = image_transforms
    self.caption_transforms = caption_transforms
    self.images_path = images_path
    self.images_fnames = sorted(os.listdir(images_path))
    self.df = caption_df
    self.captions_per_image = captions_per_image

  def __len__(self):
    return len(self.images_fnames)

  def __getitem__(self, idx):
    fname = self.images_fnames[idx]
    fpath = os.path.join(self.images_path, fname)
    image = Image.open(fpath)
    
    idx *= self.captions_per_image # num of repeats
    captions = list(self.df.iloc[idx:idx+self.captions_per_image]["cleaned"])

    image = self.image_transforms(image)
    captions = [self.caption_transforms(c) for c in captions]

    return image, captions

In [9]:

# Influenced by https://github.com/siddsrivastava/Image-captioning/blob/master/model.py

class FlickrEncoderCNN(nn.Module):
  def __init__(self, 
               embedding_dim, 
               freeze):
    super().__init__()
    # Loading resnet
    resnet = torchvision.models.resnet18(pretrained=True)
    fc_in_features = resnet.fc.in_features

    # Defining layers
    modules = list(resnet.children())[:-1]
    self.resnet = nn.Sequential(*modules)
    self.fc = nn.Linear(fc_in_features, embedding_dim)

    # Freezing if needed
    if freeze:
      for param in self.resnet.parameters():
        param.requires_grad = True
    for param in self.fc.parameters():
      param.requires_grad = True
  
  def forward(self, x):
    x = self.resnet(x)
    x = x.view(x.shape[0], -1)
    x = self.fc(x)
    return x


class FlickrDecoderLSTM(nn.Module):
  def __init__(self, 
                vocab_size,
                embedding_dim,
                hidden_size,
                padding_idx,
                bidirectional=False,
                dropout=0):
    
    super().__init__()
    # Creating embeddings
    self.embed = nn.Embedding(num_embeddings=vocab_size,
                              embedding_dim=embedding_dim, 
                              padding_idx=padding_idx,
                              scale_grad_by_freq=True,
                              sparse=True,
                              )

    self.lstm = nn.LSTM(input_size=embedding_dim,
                        hidden_size=hidden_size,
                        batch_first=True,
                        dropout=dropout,
                        bidirectional=bidirectional)
    
    self.fc = nn.Linear(hidden_size,vocab_size)

  def forward(self, features, caption_seqs):
      caption_seqs = caption_seqs[:,:-1] 
      embeddings = self.embed(caption_seqs)
      total_input = torch.cat((features.unsqueeze(1), embeddings), 1)
      lstm_out, hidden = self.lstm(total_input)
      outputs = self.fc(lstm_out)
      return outputs, hidden


class FlickrNet(nn.Module):
  def __init__(self, 
               vocab_size,
               embedding_dim,
               hidden_size,
               padding_idx,
               end_token_index,
               bidirectional=False,
               dropout=0,
               freeze=True):
    super().__init__()
    encoder = FlickrEncoderCNN(embedding_dim, freeze)
    decoder = FlickrDecoderLSTM(vocab_size, 
                                    embedding_dim, 
                                    hidden_size, 
                                    padding_idx, 
                                    bidirectional, 
                                    dropout)
    self.encoder = encoder
    self.decoder = decoder
    self.modules = nn.ModuleList([encoder,decoder])
    self.end_token_index = end_token_index

  def forward(self, images, caption_seqs):
    features = self.encoder(images)
    outputs = self.decoder(features, caption_seqs)
    return outputs
  
  def predict(self, image, states, max_len = 30):
    with torch.no_grad():
      # Needs to be turned into 1 x Channels x Width x Height
      features = self.encoder(image.unsqueeze(0))
      caption_seq = torch.FloatTensor(features.shape[0], 1, 1)
      for _ in range(max_len):
        outputs, hidden = self.decoder(features, caption_seq)
        predicted_index = F.softmax(outputs).argmax()[..., -1,...].unsqueeze(-1)
        caption_seq = torch.cat((caption_seq, predicted_index), dim=1)
        if predicted_index.item() == self.end_token_index:
          break
    return caption_seq

In [10]:
# This is needed for batches

class RepeatImages:
  def __init__(self, num_repeat=CAPTIONS_PER_IMAGE):
    self.num_repeat = num_repeat

  def __call__(self, batch):
    result = []
    for image, captions in batch:
      for i in range(self.num_repeat):
        result.append((image, captions[i]))
    return result

class PadCaptions:
  def __init__(self, vocab):
    self.pad_idx = vocab[PAD_TOKEN]
  
  def __call__(self, batch):
    captions = []
    images = []
    for im, cap in batch:
      captions.append(cap)
      images.append(im.unsqueeze(0))
    captions = torch.nn.utils.rnn.pad_sequence(captions,
                                    batch_first=True,
                                    padding_value=self.pad_idx)
    images = torch.cat(images, dim=0)
    return images, captions


batch_transforms = transforms.Compose([
     RepeatImages(),
     PadCaptions(vocab),
     ]
)

# Preparing for Training

### Defining training function

In [11]:
def train(model, optimizer, loss_function, train_loader, max_epochs=50):
  scaler = torch.cuda.amp.GradScaler()
  info = []
  for epoch in range(1, max_epochs+1):
    epoch_loss = 0
    n_items = 0
    pbar = tqdm(total=len(train_loader), desc=f"Epoch: {epoch}")
    for batch in train_loader:
      torch.cuda.empty_cache() 
      optimizer.zero_grad()

      images, captions = batch
      images, captions = images.cuda(), captions.cuda()
      
      with torch.cuda.amp.autocast():
        outputs = model(images, captions)
        loss = loss_function(captions, outputs)

      # scaling the loss
      loss = scaler.scale(loss)
      loss.backward()
      scaler.step(optimizer)
      scaler.update()

      epoch_loss += loss.item()
      n_items = captions.numel() # Batchsize
      pbar.set_postfix(batch_loss=f"{epoch_loss/captions.numel():.3f}")
      pbar.update()
    epoch_loss /= n_items()
    info.append(
        {"loss": epoch_loss}
    )
    print(f"Epoch: {epoch}, Loss: {epoch_loss:.3f}")
  return info

### Defining Test Function

In [12]:
def test(model, loss_function, test_loader, batch_transforms):
  scaler = torch.cuda.amp.GradScaler()
  test_loss = 0
  with torch.no_grad():
    for batch in test_loader:
        images, captions = batch
        images, captions = images.cuda(), captions.cuda()

        with torch.cuda.amp.autocast():
          outputs = model(images, captions)
          loss = loss_function(captions, outputs)
        

        test_loss += scaler.scale(loss).item()
        n_items += captions.numel() # Batchsize
        
    test_loss /= n_items()
  print(f"Test Loss: {test_loss:.3f}")
  return test_loss

In [13]:
dataset = FlickrDataset(
    images_path=IMAGE_ROOT, 
    caption_df=caption_df,
    image_transforms=image_transforms, 
    caption_transforms=caption_transforms,
)
test_length = int(TEST_SPLIT * len(dataset))
train_length = len(dataset) - test_length
train_dataset, test_dataset = random_split(dataset, [train_length, test_length])

In [14]:
train_batch_size = 32
test_batch_size = 128

train_loader = DataLoader(train_dataset, batch_size=train_batch_size, collate_fn=batch_transforms)
test_loader = DataLoader(test_dataset, batch_size=test_batch_size, collate_fn=batch_transforms)

# Creating models and training

In [15]:
model = FlickrNet(
    vocab_size=len(vocab),
    embedding_dim=EMBEDDING_DIM,
    hidden_size=EMBEDDING_DIM,
    padding_idx=vocab[PAD_TOKEN],
    end_token_index=vocab[END_TOKEN],           
    bidirectional=False,
    dropout=0,
    freeze=True
    ).cuda()
loss_function = nn.CrossEntropyLoss(reduction="sum").cuda()

optimizer = torch.optim.Adam(model.parameters(), weight_decay=0.01)




In [16]:
train_info = train(model, optimizer, loss_function, train_loader, max_epochs=50)

HBox(children=(FloatProgress(value=0.0, description='Epoch: 1', max=228.0, style=ProgressStyle(description_wid…

RuntimeError: ignored