In [None]:
%matplotlib inline
import os
import cv2
import gc
import numpy as np
import pandas as pd
import itertools
from tqdm import tqdm
import matplotlib.pyplot as plt
from torch import nn
import torch.nn.functional as F
import timm
import json
import logging
from pathlib import Path
import random
import tarfile
import tempfile
import warnings
import timm
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
#import pandas_path  # Path style access for pandas
from tqdm import tqdm
import torch                    
import torchvision
import fasttext
import pandas_path
from transformers import DistilBertModel, DistilBertConfig, DistilBertTokenizer

In [None]:
# Locations of the Hateful Memes dataset and the JSON-lines annotation files
# that accompany it, resolved from the parent of the current working directory.
data_dir = Path.cwd().parent.joinpath("Data", "hateful_memes")
img_tar_path = data_dir.joinpath("img.tar.gz")
train_path = data_dir.joinpath("train.jsonl")
dev_path = data_dir.joinpath("dev_seen.jsonl")
test_path = data_dir.joinpath("test_seen.jsonl")

In [None]:
# Read the "dev_seen" annotations (one JSON object per line) and display them.
dev_path_frame = pd.read_json(dev_path, lines=True)
dev_path_frame

In [None]:
# Display the resolved path of the training annotations file.
train_path

In [None]:
# Read the training annotations (one JSON object per line) and display them.
train_samples_frame = pd.read_json(train_path, lines=True)
train_samples_frame

In [None]:
# Load the DistilBERT tokenizer for the "distilbert-base-uncased" checkpoint.
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

In [None]:
def make_train_valid_dfs(path):
    """Split a JSON-lines annotation file into training and validation frames.

    The split is drawn over the id range ``0 .. max(id)`` rather than over the
    rows themselves: 20% of that range is sampled (without replacement) as
    validation ids and the remaining ids are training ids. Ids in the range
    that do not occur in the file match no row, so the validation share of the
    actual rows is only approximately 20%.

    Args:
        path: Location of a JSON-lines file with an integer ``id`` column.

    Returns:
        ``(train_dataframe, valid_dataframe)``, both with a fresh 0-based index.

    Side effects:
        Re-seeds NumPy's global random generator with 42.
    """
    dataframe = pd.read_json(path, lines=True)
    max_id = dataframe["id"].max() + 1
    image_ids = np.arange(0, max_id)
    # Fixed seed so every experiment trains on the same training set.
    np.random.seed(42)
    valid_ids = np.random.choice(
        image_ids, size=int(0.2 * len(image_ids)), replace=False
    )
    # Vectorised set difference instead of a Python-level membership test per id.
    train_ids = np.setdiff1d(image_ids, valid_ids)
    train_dataframe = dataframe[dataframe["id"].isin(train_ids)].reset_index(drop=True)
    valid_dataframe = dataframe[dataframe["id"].isin(valid_ids)].reset_index(drop=True)
    return train_dataframe, valid_dataframe

In [None]:
# Split the training annotations into train / validation DataFrames.
# NOTE(review): the name `train` is rebound further down by `def train(...)`;
# cells that read `train` as a DataFrame must run before that definition.
train,val=make_train_valid_dfs(train_path)

In [None]:
class CLIPDataset(torch.utils.data.Dataset):
    """Image / caption / label dataset that tokenizes all captions up front.

    Each item is a dict holding one tensor per tokenizer output key, plus
    ``image`` (float tensor, channels first), ``caption`` (the raw text) and
    ``label`` (the label as supplied).
    """

    def __init__(self, image_filenames, captions,labels, tokenizer, transforms, max_length=200):
        """
        Args:
            image_filenames: File names relative to ``CFG.image_path``.
                image_filenames and captions must have the same length; so, if
                there are multiple captions for each image, image_filenames
                must repeat the file name.
            captions: One text per sample.
            labels: One label per sample.
            tokenizer: Callable invoked as
                ``tokenizer(texts, padding=True, truncation=True, max_length=...)``
                that returns a mapping from key to per-sample sequences.
            transforms: Callable invoked as ``transforms(image=array)`` that
                returns a mapping with an ``'image'`` entry.
            max_length: Truncation length handed to the tokenizer.
        """

        self.image_filenames = image_filenames
        self.captions = list(captions)
        self.labels = list(labels)
        self.encoded_captions = tokenizer(
            self.captions, padding=True, truncation=True, max_length=max_length
        )
        self.transforms = transforms

    def __getitem__(self, idx):
        """Return the sample dict for position ``idx``.

        Raises:
            FileNotFoundError: If OpenCV cannot read the image file.
        """
        item = {
            key: torch.tensor(values[idx])
            for key, values in self.encoded_captions.items()
        }
        image_file = f"{CFG.image_path}/{self.image_filenames[idx]}"
        image = cv2.imread(image_file)
        if image is None:
            # cv2.imread signals failure by returning None; report the path
            # instead of letting cvtColor fail with an opaque assertion.
            raise FileNotFoundError(f"Could not read image: {image_file}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = self.transforms(image=image)['image']
        # Reorder (H, W, C) -> (C, H, W).
        item['image'] = torch.tensor(image).permute(2, 0, 1).float()
        item['caption'] = self.captions[idx]
        item['label'] = self.labels[idx]

        return item


    def __len__(self):
        """Number of samples (one per caption)."""
        return len(self.captions)

In [None]:
import albumentations as A
# Resize every image to 384x384, then apply A.Normalize with its default statistics.
transforms = A.Compose(
    [
        A.Resize(384, 384, always_apply=True),
        A.Normalize(max_pixel_value=255.0, always_apply=True),
    ]
)


def _build_loader(frame, shuffle):
    """Wrap the img/text/label columns of `frame` in a CLIPDataset and a DataLoader."""
    frame_dataset = CLIPDataset(
        frame["img"].values,
        frame["text"].values,
        frame["label"].values,
        tokenizer=tokenizer,
        transforms=transforms,
    )
    frame_loader = torch.utils.data.DataLoader(
        frame_dataset, batch_size=4, num_workers=3, shuffle=shuffle
    )
    return frame_dataset, frame_loader


# Training batches are shuffled; validation batches keep the frame order.
dataset, dataloader = _build_loader(train, shuffle=True)
val_dataset, val_dataloader = _build_loader(val, shuffle=False)


In [None]:
# Folder that holds the images and the caption files.
image_path = "/scratch/ps4364/HM/Data/hateful_memes"
captions_path = "/scratch/ps4364/HM/Data/hateful_memes"


class CFG:
    """Static namespace of settings read by the dataset, encoders and training loop."""

    # --- data locations (taken from the module-level names above) ---
    image_path = image_path
    captions_path = captions_path

    # --- run control ---
    debug = False
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    epochs = 4
    batch_size = 32
    num_workers = 2

    # --- optimisation ---
    head_lr = 1e-3
    image_encoder_lr = 1e-4
    text_encoder_lr = 1e-5
    weight_decay = 1e-3
    patience = 1
    factor = 0.8

    # --- image encoder ---
    model_name = 'resnet50'
    image_embedding = 2048
    size = 384  # image size

    # --- text encoder ---
    text_encoder_model = "distilbert-base-uncased"
    text_tokenizer = "distilbert-base-uncased"
    text_embedding = 768
    max_length = 200

    # --- shared by the image encoder and the text encoder ---
    pretrained = True
    trainable = True
    temperature = 1.0

    # --- projection head; used for both image and text encoders ---
    num_projection_layers = 1
    projection_dim = 256
    dropout = 0.1
    
class CLIPModel(nn.Module):
    """Image/text dual encoder trained with a symmetric soft-target contrastive loss.

    Both encoder outputs are projected to a common dimension; the forward pass
    returns the scalar loss for a batch rather than the embeddings.
    """

    def __init__(
        self,
        temperature=CFG.temperature,
        image_embedding=CFG.image_embedding,
        text_embedding=CFG.text_embedding,
    ):
        """
        Args:
            temperature: Divides the text-image logits and multiplies the
                averaged similarities used to build the soft targets.
            image_embedding: Width of the image encoder output.
            text_embedding: Width of the text encoder output.
        """
        super().__init__()
        self.image_encoder = ImageEncoder()
        self.text_encoder = TextEncoder()
        self.image_projection = ProjectionHead(embedding_dim=image_embedding)
        self.text_projection = ProjectionHead(embedding_dim=text_embedding)
        self.temperature = temperature

    @staticmethod
    def _soft_cross_entropy(logits, targets):
        """Per-row cross-entropy between ``logits`` and a soft target distribution."""
        return (-targets * F.log_softmax(logits, dim=-1)).sum(dim=1)

    def forward(self, batch):
        """Compute the mean contrastive loss for ``batch``.

        Args:
            batch: Mapping with ``image``, ``input_ids`` and ``attention_mask``.

        Returns:
            Scalar tensor: mean over the batch of the averaged text and image losses.
        """
        # Getting Image and Text Features
        image_features = self.image_encoder(batch["image"])
        text_features = self.text_encoder(
            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
        )
        # Getting Image and Text Embeddings (with same dimension)
        image_embeddings = self.image_projection(image_features)
        text_embeddings = self.text_projection(text_features)

        # Calculating the Loss
        logits = (text_embeddings @ image_embeddings.T) / self.temperature
        images_similarity = image_embeddings @ image_embeddings.T
        texts_similarity = text_embeddings @ text_embeddings.T
        targets = F.softmax(
            (images_similarity + texts_similarity) / 2 * self.temperature, dim=-1
        )
        # The notebook never defines or imports a `cross_entropy` helper, so the
        # soft-target loss is computed by the method above instead.
        texts_loss = self._soft_cross_entropy(logits, targets)
        images_loss = self._soft_cross_entropy(logits.T, targets.T)
        loss =  (images_loss + texts_loss) / 2.0 # shape: (batch_size)
        return loss.mean()

class ProjectionHead(nn.Module):
    """Maps encoder features to the projection space with a residual refinement.

    Pipeline: linear projection -> GELU -> linear -> dropout, added back onto
    the linear projection and passed through layer normalisation.
    """

    def __init__(
        self,
        embedding_dim,
        projection_dim=CFG.projection_dim,
        dropout=CFG.dropout
    ):
        """
        Args:
            embedding_dim: Width of the incoming features.
            projection_dim: Width of the projected output.
            dropout: Dropout probability applied before the residual sum.
        """
        super().__init__()
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        """Project ``x`` and return the layer-normalised residual output."""
        shortcut = self.projection(x)
        refined = self.dropout(self.fc(self.gelu(shortcut)))
        return self.layer_norm(refined + shortcut)
    
class TextEncoder(nn.Module):
    """DistilBERT wrapper that returns one hidden vector per input sequence."""

    def __init__(self, model_name=CFG.text_encoder_model, pretrained=CFG.pretrained, trainable=CFG.trainable):
        """
        Args:
            model_name: Checkpoint name passed to ``from_pretrained``.
            pretrained: Load pretrained weights when true; otherwise build the
                model from a default ``DistilBertConfig``.
            trainable: Value assigned to ``requires_grad`` of every parameter.
        """
        super().__init__()
        self.model = (
            DistilBertModel.from_pretrained(model_name)
            if pretrained
            else DistilBertModel(config=DistilBertConfig())
        )

        for param in self.model.parameters():
            param.requires_grad = trainable

        # The hidden state of the CLS token serves as the sentence embedding.
        self.target_token_idx = 0

    def forward(self, input_ids, attention_mask):
        """Return the last hidden state at position ``target_token_idx``."""
        hidden_states = self.model(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        return hidden_states[:, self.target_token_idx, :]
class ImageEncoder(nn.Module):
    """
    Encode images to a fixed size vector using a timm backbone created with
    no classifier (``num_classes=0``) and average global pooling.
    """

    def __init__(
        self, model_name=CFG.model_name, pretrained=CFG.pretrained, trainable=CFG.trainable
    ):
        """
        Args:
            model_name: timm architecture name.
            pretrained: Forwarded to ``timm.create_model``.
            trainable: Value assigned to ``requires_grad`` of every parameter.
        """
        super().__init__()
        backbone = timm.create_model(
            model_name, pretrained, num_classes=0, global_pool="avg"
        )
        for weight in backbone.parameters():
            weight.requires_grad = trainable
        self.model = backbone

    def forward(self, x):
        """Return the backbone output for the image batch ``x``."""
        return self.model(x)
    


In [None]:
class LanguageAndVisionConcat(torch.nn.Module):
    """Concatenation-fusion classifier on top of a text and an image encoder.

    Each encoder output is passed through its own linear adapter, the two
    adapted vectors are concatenated, fused by a linear + ReLU + dropout
    layer and classified by a final linear layer.
    """

    def __init__(self,language_module,vision_module,
                 language_feature_dim=300,vision_feature_dim=300,fusion_output_size=512,dropout_p=0.1,num_classes=2,
                 language_input_dim=768,vision_input_dim=512):
        """
        Args:
            language_module: Module called as ``language_module(input_ids, attention_mask)``.
            vision_module: Module called as ``vision_module(image)``.
            language_feature_dim: Output width of the text adapter.
            vision_feature_dim: Output width of the image adapter.
            fusion_output_size: Width of the fused representation.
            dropout_p: Dropout probability applied after the fusion layer.
            num_classes: Number of output classes.
            language_input_dim: Output width of ``language_module``.
            vision_input_dim: Output width of ``vision_module``.
                NOTE(review): the default of 512 is kept from the original
                code; confirm it matches the encoder actually loaded.
        """
        super(LanguageAndVisionConcat, self).__init__()
        self.language_module = language_module
        # The adapters honour the *_feature_dim arguments so that their output
        # always matches the input width of the fusion layer.
        self.language_adder=nn.Linear(language_input_dim,language_feature_dim)
        self.vision_module = vision_module
        self.vision_adder=nn.Linear(vision_input_dim,vision_feature_dim)
        self.fusion = torch.nn.Linear(
            in_features=(language_feature_dim + vision_feature_dim),
            out_features=fusion_output_size
        )
        self.fc = torch.nn.Linear(
            in_features=fusion_output_size,
            out_features=num_classes
        )
        self.dropout = torch.nn.Dropout(dropout_p)

    def forward(self, x):
        """Classify a batch.

        Args:
            x: Mapping with ``input_ids``, ``attention_mask`` and ``image``
                tensors, already placed on the same device as the model.

        Returns:
            Raw class logits with one row per sample and ``num_classes``
            columns. No softmax is applied, because
            ``torch.nn.CrossEntropyLoss`` applies log-softmax itself; use
            ``softmax(logits, dim=1)`` to obtain probabilities.
        """
        # Use the encoders handed to the constructor instead of a notebook
        # global, and leave device placement to the caller.
        language_features = self.language_module(x['input_ids'], x['attention_mask'])
        language_op = self.language_adder(language_features)

        vision_features = self.vision_module(x['image'])
        vision_op = self.vision_adder(vision_features)
        combined = torch.cat(
            [language_op, vision_op], dim=1
        )
        fused = self.dropout(
            torch.nn.functional.relu(
            self.fusion(combined)
            )
        )
        logits = self.fc(fused)
        return logits

In [None]:
# Load the pretrained model that will be fine-tuned with full supervision.
# SECURITY: torch.load unpickles arbitrary Python objects; only load checkpoints
# from a trusted source.
# NOTE(review): recent PyTorch releases may default torch.load to
# weights_only=True, which rejects fully pickled modules — confirm against the
# installed version.
# map_location keeps the load working on hosts without CUDA.
model1 = torch.load(
    './vitb16-r50-CNNPART-CASS-BERT-384-logits-FINAL.pt',
    map_location=CFG.device,
)
# Use CFG.device for placement, like the batches in the training loop.
model = LanguageAndVisionConcat(model1.text_encoder, model1.image_encoder).to(CFG.device)
optimizer = torch.optim.AdamW(
                model.parameters(),
                lr=0.00005)

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
loss_fn=torch.nn.CrossEntropyLoss()
from sklearn.metrics import *
import math
def _run_epoch(model, loader, loss_fn, device, optimizer=None):
    """Run one pass over `loader`, updating weights only when an optimizer is given.

    Returns a tuple ``(mean_batch_loss, labels, predictions, scores)`` where
    the three lists cover every sample of the pass. ``scores`` is the margin
    ``output[:, 1] - output[:, 0]``, which assumes a two-column model output.
    """
    is_training = optimizer is not None
    total_loss = 0.0
    labels, predictions, scores = [], [], []
    with torch.set_grad_enabled(is_training):
        for x in loader:
            batch = {k: v.to(device) for k, v in x.items() if k != "caption"}
            if is_training:
                # Clear the gradients left over from the previous step.
                optimizer.zero_grad()
            output = model(batch)
            loss = loss_fn(output, batch['label'])
            if is_training:
                loss.backward()
                optimizer.step()
            total_loss += loss.item()
            labels.extend(batch['label'].detach().cpu().tolist())
            predictions.extend(torch.argmax(output, dim=1).detach().cpu().tolist())
            # The margin ranks samples identically whether the model emits
            # logits or two-class probabilities.
            scores.extend((output[:, 1] - output[:, 0]).detach().cpu().tolist())
    return total_loss / len(loader), labels, predictions, scores


def _binary_metrics(labels, predictions, scores):
    """Return ``(auc, accuracy)`` for one pass; AUC uses the continuous scores."""
    fpr, tpr, thresholds = roc_curve(np.asarray(labels), np.asarray(scores))
    return auc(fpr, tpr), accuracy_score(np.asarray(labels), np.asarray(predictions))


def _step_scheduler(scheduler, metric):
    """Advance `scheduler` once, passing `metric` when it is a plateau scheduler."""
    if scheduler is None or isinstance(scheduler, torch.optim.Optimizer):
        # Tolerate callers that pass nothing useful (or the optimizer itself)
        # in the scheduler slot: there is no schedule to advance.
        return
    if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
        scheduler.step(metric)
    else:
        scheduler.step()


def train(model,
    optimizer,
    scheduler,
    loss_fn,epochs,
    train_loader=None,
    val_loader=None,
    device=None,
    save_path='./vitb16-r50-CNNPART-CASS-BERT-384-logits-FINAL-FT.pt'):
    """Fine-tune `model`, validating and checkpointing on improvement.

    Every epoch runs one training pass and prints the mean batch loss, AUC and
    accuracy over the whole pass. Validation runs only in epochs whose
    training loss is a new minimum; the model is saved to `save_path` whenever
    the validation loss is a new minimum. The scheduler is advanced once per
    epoch with the epoch's training loss.

    Args:
        model: Module called as ``model(batch)`` that returns one row of class
            scores per sample.
        optimizer: Optimizer over the parameters of `model`.
        scheduler: Learning-rate scheduler, or ``None``. An optimizer passed
            here is ignored.
        loss_fn: Callable ``loss_fn(output, labels)`` returning a scalar tensor.
        epochs: Number of epochs to run.
        train_loader: Training batches; defaults to the notebook-level
            ``dataloader``.
        val_loader: Validation batches; defaults to the notebook-level
            ``val_dataloader``.
        device: Device the batches are moved to; defaults to ``CFG.device``.
        save_path: Destination of the saved model.

    Returns:
        None.
    """
    train_loader = dataloader if train_loader is None else train_loader
    val_loader = val_dataloader if val_loader is None else val_loader
    device = CFG.device if device is None else device

    best_train_loss=math.inf
    best_val_loss=math.inf
    for epoch in tqdm(range(epochs)):
        model.train()
        total_loss, labels, predictions, scores = _run_epoch(
            model, train_loader, loss_fn, device, optimizer
        )
        train_roc, current_train_acc = _binary_metrics(labels, predictions, scores)
        print('For this epoch the loss was:',total_loss,'AUC value:',train_roc,'And acc:',current_train_acc)
        if best_train_loss>total_loss:
            # Validate only when the training loss reaches a new minimum.
            best_train_loss=total_loss
            print('Validating!')
            model.eval()
            val_loss, val_labels, val_predictions, val_scores = _run_epoch(
                model, val_loader, loss_fn, device
            )
            val_roc, current_val_acc = _binary_metrics(val_labels, val_predictions, val_scores)
            print('For this epoch the loss was:',val_loss,'AUC value:',val_roc,'And acc:',current_val_acc)
            if best_val_loss>val_loss:
                # Save only when the validation loss reaches a new minimum.
                best_val_loss=val_loss
                print('Saving Model!')
                torch.save(model,save_path)
        _step_scheduler(scheduler, total_loss)

In [None]:
# Suppress every Python warning and set the root logger level to WARNING
# before training starts.
warnings.filterwarnings("ignore")
logging.getLogger().setLevel(logging.WARNING)
# Run 50 epochs (CFG.epochs is not used here).
# NOTE(review): the third positional argument fills train()'s `scheduler`
# parameter, but `optimizer` is passed there, so the ReduceLROnPlateau object
# bound to `scheduler` is never used — confirm whether that is intended.
train(model,optimizer,optimizer,loss_fn,50)