In [7]:
from __future__ import print_function, division
import os
import torch
from typing import Callable, Dict, List, Optional, Tuple, Union
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.datasets as dset
import random
from transformers import EncoderDecoderModel, GPT2Tokenizer, ViTFeatureExtractor
from tqdm.auto import tqdm
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import torchvision.datasets as dset
import multiprocessing as mp
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import pytorch_lightning as pl
from deepspeed.ops.adam import FusedAdam
from pytorch_lightning.loggers import WandbLogger
# import wandb

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

plt.ion()   # interactive mode

<contextlib.ExitStack at 0x2a5b23a60>

## Config

In [8]:
VIT_MODEL = "google/vit-base-patch16-224-in21k"
GPT2 = "gpt2"
DISTIL_GPT2 = "distilgpt2"

DATA_PATH = "Processed_Frames/"

# "/content/drive/MyDrive/CS640 Project/Processed Frames/"
ANNOTATION_PATH = "/content/drive/MyDrive/CS640 Project/Y.json"

MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)

TRAIN_PCT = 0.95
NUM_WORKERS = mp.cpu_count()
BATCH_SIZE = 8
EPOCHS = 3
LR = 1e-4
IMAGE_SIZE = (224, 224)

MAX_TEXT_LENGTH = 32

LABEL_MASK = -100

TOP_K = 1000
TOP_P = 0.95

In [10]:
y = pd.read_csv('Y.csv')
# y.to_json('/content/drive/MyDrive/CS640 Project/Y.json')

In [11]:
y

Unnamed: 0,label
0,Caroline Ingalls is a ZOMBIE!! -- New subreddi...
1,Caroline Ingalls is a ZOMBIE!! -- New subreddi...
2,Caroline Ingalls is a ZOMBIE!! -- New subreddi...
3,Caroline Ingalls is a ZOMBIE!! -- New subreddi...
4,Caroline Ingalls is a ZOMBIE!! -- New subreddi...
...,...
2049,Elpedroym0.png
2050,Elpedroym1.png
2051,Elpedroym2.png
2052,Elpedroym3.png


## Data

In [56]:
class GifDataset(Dataset):
    """Gif Caption dataset."""

    def __init__(self, csv_file, root_dir, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.labels = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        print(idx)
        print(self.labels.iloc[idx, 0])
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name = os.path.join(self.root_dir,
                                self.labels.iloc[idx, 0])
        print(self.labels.iloc[idx, 0])

        image = io.imread(img_name)
        
        labels = self.labels.iloc[idx, 0]
        labels = np.array([labels])
        
        # labels = labels.astype('float').reshape(-1, 2)
        sample = {'image': image, 'labels': labels, 'idx': idx}
        
        if self.transform:
            # print(sample['labels'].shape)
            sample = self.transform(sample)
            
        return sample

In [57]:
class Rescale(object):
    """Rescale the image in a sample to a given size.

    Args:
        output_size (tuple or int): Desired output size. If tuple, output is
            matched to output_size. If int, smaller of image edges is matched
            to output_size keeping aspect ratio the same.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        image, idx = sample['image'], sample['idx']

        h, w = image.shape[:2]
        if isinstance(self.output_size, int):
            if h > w:
                new_h, new_w = self.output_size * h / w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size * w / h
        else:
            new_h, new_w = self.output_size

        new_h, new_w = int(new_h), int(new_w)

        img = transform.resize(image, (new_h, new_w))

        # h and w are swapped for landmarks because for images,
        # x and y axes are axis 1 and 0 respectively
        # landmarks = landmarks * [new_w / w, new_h / h]
        # print(landmarks)

        return {'image': img, 'idx': idx}


class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""

    def __call__(self, sample):
        image, idx = sample['image'], sample['idx']

        # swap color axis because
        # numpy image: H x W x C
        # torch image: C X H X W
        image = image.transpose((2, 0, 1))
        # print(landmarks)
        # return {'image': torch.from_numpy(image),
        #         'labels': landmarks}
        # print(landmarks)
        return {'image': torch.from_numpy(image), 'idx': idx}

def show_gifs(image, labels):
    """Show image with landmarks"""
    plt.imshow(image)
    plt.scatter(labels[:, 0], labels[:, 1], s=10, marker='.', c='r')
    plt.pause(0.001)


scale = Rescale(512)
composed = transforms.Compose([Rescale((224,224))])

In [66]:
image_path = 'Processed_Frames/'
gif_data = GifDataset(csv_file='Y.csv',
                                    root_dir= image_path)

In [67]:
transformed_dataset = GifDataset(csv_file='Y.csv',
                                           root_dir=image_path,
                                           transform=transforms.Compose([
                                               ToTensor()
                                           ]))

In [68]:
train_len = int(TRAIN_PCT * len(transformed_dataset))
train_data, valid_data = random_split(transformed_dataset, [train_len, len(transformed_dataset) - train_len])
train_dl = DataLoader(
    train_data, 
    BATCH_SIZE, 
    pin_memory=True, 
    shuffle=True, 
    num_workers=0, 
    drop_last=True
)
valid_dl = DataLoader(
    valid_data, 
    BATCH_SIZE, 
    pin_memory=True, 
    shuffle=False, 
    num_workers=0, 
    drop_last=False
)

# # images, captions = next(iter(train_dl))
# images, idx = next(iter(train_dl))
# images = images
# images.shape, images.min(), images.max(), images.mean(), images.std()

In [69]:
for i in train_dl:
    images = i['image']
    idx = i['idx']

1106
When Elmo Is Talking You Shut Your Trap37.png
When Elmo Is Talking You Shut Your Trap37.png
130
Neighbor kids causing mischief...47.png
Neighbor kids causing mischief...47.png
1153
When Elmo Is Talking You Shut Your Trap84.png
When Elmo Is Talking You Shut Your Trap84.png
1863
Funny Pasta16.png
Funny Pasta16.png
1116
When Elmo Is Talking You Shut Your Trap47.png
When Elmo Is Talking You Shut Your Trap47.png
1449
A Murder Knife Mystery15.png
A Murder Knife Mystery15.png
1964
The end of the line for Microsoft Customer Support.0.png
The end of the line for Microsoft Customer Support.0.png
1349
Let’s set off a fire cracker near some open windows7.png
Let’s set off a fire cracker near some open windows7.png
1499
gottem6.png
gottem6.png
837
Skte17.png
Skte17.png
1043
Iran coming up..2.png
Iran coming up..2.png
1261
Fly Me To The Moon17.png
Fly Me To The Moon17.png
1457
Not giving AF0.png
Not giving AF0.png
1285
Fly Me To The Moon41.png
Fly Me To The Moon41.png
1219
When you got pussy on

FileNotFoundError: No such file: '/Users/aakashbhatnagar/Documents/masters/CS 640 AI/Project/Processed_Frames/When you got pussy on your mind but you're stuck in court2.png'

In [62]:
y[y['label'] == 'If the veggies can do it then why can’t you.3.png']

Unnamed: 0,label
1416,If the veggies can do it then why can’t you.3.png


In [70]:
len(os.listdir(image_path)

2054

In [None]:
j = 0
for i in transformed_dataset:
    print(i)

In [78]:
idx

'idx'

In [77]:
n = {'a':[1,2], 'b': [4,5]}
next(iter(n))

'a'

In [65]:
np.array(['Neighbor kids causing mischief...136.png'])

array(['Neighbor kids causing mischief...136.png'], dtype='<U40')

## Model

In [None]:
# make sure GPT2 appends EOS in begin and end
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
    outputs = [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
    return outputs
    
GPT2Tokenizer.build_inputs_with_special_tokens = build_inputs_with_special_tokens
gpt2_tokenizer = GPT2Tokenizer.from_pretrained(DISTIL_GPT2)
# set pad_token_id to unk_token_id -> be careful here as unk_token_id == eos_token_id == bos_token_id
gpt2_tokenizer.pad_token = gpt2_tokenizer.unk_token

gpt2_tokenizer_fn = lambda x: gpt2_tokenizer(
    x,
    max_length=MAX_TEXT_LENGTH,
    truncation=True,
    padding=True,
    return_tensors="pt",
)

vit2gpt2 = EncoderDecoderModel.from_encoder_decoder_pretrained(VIT_MODEL, DISTIL_GPT2)

## Nucleus Sampling
[Paper](https://arxiv.org/pdf/1904.09751.pdf)

In [None]:
def top_k_top_p_filtering(
    next_token_logits: torch.FloatTensor,
    top_k: Optional[float]=None, 
    top_p: Optional[float]=None,
    device: Union[str, torch.device]="cpu",
) -> torch.FloatTensor:
    if top_k is None:
        top_k = next_token_logits.shape[-1]
    if top_p is None:
        top_p = 1.0
        
    p, largest_p_idx = F.softmax(next_token_logits, dim=-1).topk(top_k, dim=-1)
    cumulative_p = p.cumsum(dim=-1)
    threshold_repeated = top_p + torch.zeros((len(p),1)).to(device)
    idx = torch.searchsorted(cumulative_p, threshold_repeated).clip(max=top_k-1).squeeze()
    cutoffs = cumulative_p[torch.arange(len(cumulative_p)), idx]
    censored_p = (cumulative_p <= cutoffs[:, None]) * p
    renormalized_p = censored_p / censored_p.sum(dim=-1, keepdims=True)
    
    final_p = torch.zeros_like(next_token_logits)
    row_idx = torch.arange(len(p)).unsqueeze(1).repeat(1,top_k).to(device)
    final_p[row_idx, largest_p_idx] = renormalized_p.to(final_p.dtype)

    return final_p

def generate_sentence_from_image(model, encoder_outputs, tokenizer, max_text_length: int, device)-> List[str]:
    generated_so_far = torch.LongTensor([[tokenizer.bos_token_id]]*len(encoder_outputs.last_hidden_state)).to(device)
    with torch.no_grad():
        for _ in tqdm(range(max_text_length)):
            attention_mask = torch.ones_like(generated_so_far)
            decoder_out = model(
                decoder_input_ids=generated_so_far, 
                decoder_attention_mask=attention_mask,
                encoder_outputs=encoder_outputs
            )

            next_token_logits = decoder_out["logits"][:, -1, :]
            filtered_p = top_k_top_p_filtering(next_token_logits, top_k=TOP_K, top_p=TOP_P, device=device)
            next_token = torch.multinomial(filtered_p, num_samples=1)
            generated_so_far = torch.cat((generated_so_far, next_token), dim=1)

    return [tokenizer.decode(coded_sentence) for coded_sentence in generated_so_far]

## Training Module (PyTorch Lightning)

In [None]:
class LightningModule(pl.LightningModule):
    def __init__(
        self,
        model: nn.Module,
        tokenizer,
        lr: float,
    ):
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.lr = lr
        
        for name, param in self.model.named_parameters():
            if "crossattention" not in name:
                param.requires_grad = False
        
    def common_step(self, batch: Tuple[torch.FloatTensor, List[str]]) -> torch.FloatTensor:
        images, captions = batch
        tokenized_captions = {
            k: v.to(self.device) for k, v in 
            self.tokenizer(
                captions,
                max_length=MAX_TEXT_LENGTH,
                truncation=True,
                padding=True,
                return_tensors="pt",
            ).items()
        }
        labels = tokenized_captions["input_ids"].clone()
        labels[tokenized_captions["attention_mask"]==0] = LABEL_MASK
        encoder_outputs = self.model.encoder(pixel_values=images)
        outputs = self.model(
            encoder_outputs=encoder_outputs,
            decoder_input_ids=tokenized_captions["input_ids"],
            decoder_attention_mask=tokenized_captions["attention_mask"],
            labels=labels,
            return_dict=True,
        )
        
        return outputs["loss"]
    
    def training_step(self, batch: Tuple[torch.FloatTensor, List[str]], batch_idx: int) -> torch.FloatTensor:
        loss = self.common_step(batch)
        self.log(name="Training loss", value=loss, on_step=True, on_epoch=True)
        
        return loss
        
    def validation_step(self, batch: Tuple[torch.FloatTensor, List[str]], batch_idx: int):
        loss = self.common_step(batch)
        self.log(name="Validation loss", value=loss, on_step=True, on_epoch=True)

        images, actual_sentences = batch
        
        if batch_idx == 0:
            encoder_outputs = self.model.encoder(pixel_values=images.to(self.device))
            generated_sentences = generate_sentence_from_image(
                self.model, 
                encoder_outputs, 
                self.tokenizer, 
                MAX_TEXT_LENGTH,
                self.device
            )
            images = [wandb.Image(transforms.ToPILImage()(descale(image))) for image in images]
            data = list(map(list, zip(images, actual_sentences, generated_sentences)))
            columns = ["Images", "Actual Sentence", "Generated Sentence"]
            table = wandb.Table(data=data, columns=columns)
            self.logger.experiment.log({f"epoch {self.current_epoch} results": table})
                        
    def on_after_backward(self):
        if self.trainer.global_step % 50 == 0:  # don't make the tf file huge
            for name, param in self.model.named_parameters():
                if "weight" in name and not "norm" in name and param.requires_grad:
                    self.logger.experiment.log(
                        {f"{name}_grad": wandb.Histogram(param.grad.detach().cpu())}
                    )
                    self.logger.experiment.log(
                        {f"{name}": wandb.Histogram(param.detach().cpu())}
                    )

    def configure_optimizers(self) -> torch.optim.Optimizer:
        return torch.optim.Adam(self.model.parameters(), lr=self.lr)
            

In [None]:
!mkdir -p /kaggle/working/logs
lightning_module = LightningModule(
    vit2gpt2,
    gpt2_tokenizer,
    LR
)

trainer = pl.Trainer(
    max_epochs=EPOCHS,
    gpus=torch.cuda.device_count(),
    gradient_clip_val=1.0,
    logger=WandbLogger("Frozen", "/kaggle/working/logs/", project="Vit2GPT2"),
    precision=16,
    num_sanity_val_steps=0,
)
trainer.fit(lightning_module, train_dl, valid_dl)