In [1]:
!pip install torch torchvision transformers bert-score evaluate tqdm matplotlib numpy sklearn evaluate rouge_score

Defaulting to user installation because normal site-packages is not writeable
Collecting sklearn
  Using cached sklearn-0.0.post12.tar.gz (2.6 kB)
  Preparing metadata (setup.py) ... [?25lerror
  [1;31merror[0m: [1msubprocess-exited-with-error[0m
  
  [31m×[0m [32mpython setup.py egg_info[0m did not run successfully.
  [31m│[0m exit code: [1;36m1[0m
  [31m╰─>[0m [31m[15 lines of output][0m
  [31m   [0m The 'sklearn' PyPI package is deprecated, use 'scikit-learn'
  [31m   [0m rather than 'sklearn' for pip commands.
  [31m   [0m 
  [31m   [0m Here is how to fix this error in the main use cases:
  [31m   [0m - use 'pip install scikit-learn' rather than 'pip install sklearn'
  [31m   [0m - replace 'sklearn' by 'scikit-learn' in your pip requirements files
  [31m   [0m   (requirements.txt, setup.py, setup.cfg, Pipfile, etc ...)
  [31m   [0m - if the 'sklearn' package is used by one of your dependencies,
  [31m   [0m   it would be great if you take some tim

In [2]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets, models
from tqdm import trange, tqdm
import numpy as np
from sklearn.naive_bayes import GaussianNB
from sklearn.metrics import precision_score, recall_score, f1_score
from models.stegastamp_wm import StegaStampDecoder, StegaStampEncoder
import collections
from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor,AutoTokenizer
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
from evaluate import load

meteor = load("meteor")
rouge = load("rouge")
bleu = load("bleu")

[nltk_data] Downloading package wordnet to
[nltk_data]     /Users/varunchitturi/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt_tab to
[nltk_data]     /Users/varunchitturi/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package omw-1.4 to
[nltk_data]     /Users/varunchitturi/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


In [3]:
DATASET_SIZE = 600
IMAGE_SIZE = 256
NUM_BITS = 48
IMAGE_CHANNELS = 3

if torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cuda') if torch.cuda.is_available() else 'cpu'


train_dataset = datasets.CocoCaptions(root = './data/images/train',
                        annFile = './data/annotations/train_captions.json',
                        transform=transforms.Compose([
                            transforms.Resize(IMAGE_SIZE),
                            transforms.CenterCrop(IMAGE_SIZE),
                            transforms.ToTensor()
                        ]))
val_dataset = datasets.CocoCaptions(root = './data/images/val',
                        annFile = './data/annotations/val_captions.json',
                        transform=transforms.Compose([
                            transforms.Resize(IMAGE_SIZE),
                            transforms.CenterCrop(IMAGE_SIZE),
                            transforms.ToTensor()
                        ]))


loading annotations into memory...
Done (t=0.28s)
creating index...
index created!
loading annotations into memory...
Done (t=0.14s)
creating index...
index created!


In [4]:
signature = torch.randint(0, 2, (1, NUM_BITS), device=device).float()
wm_encoder = StegaStampEncoder(
    IMAGE_SIZE,
    IMAGE_CHANNELS,
    NUM_BITS,
)
wm_encoder_load = torch.load('models/wm_stegastamp_encoder.pth', map_location=device, weights_only=True)
if type(wm_encoder_load) is collections.OrderedDict:
    wm_encoder.load_state_dict(wm_encoder_load)
else:
    wm_encoder = wm_encoder_load

wm_decoder = StegaStampDecoder(
    IMAGE_SIZE,
    IMAGE_CHANNELS,
    NUM_BITS,
)
wm_decoder_load = torch.load('models/wm_stegastamp_decoder.pth', map_location=device, weights_only=True)
if type(wm_decoder_load) is collections.OrderedDict:
    wm_decoder.load_state_dict(wm_decoder_load)
else:
    wm_encoder = wm_encoder_load

wm_encoder.to(device)
wm_decoder.to(device)


StegaStampDecoder(
  (decoder): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU()
    (4): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (5): ReLU()
    (6): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU()
    (8): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (9): ReLU()
    (10): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (11): ReLU()
    (12): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (13): ReLU()
  )
  (dense): Sequential(
    (0): Linear(in_features=8192, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=48, bias=True)
  )
)

In [5]:
class CocoCaptionMixedWMDataset(Dataset):
    def __init__(self, signature, coco_dataset, num_images):
        super(CocoCaptionMixedWMDataset, self).__init__()
        self.coco_dataset = coco_dataset
        self.dataset = []
        self.images = []
        self.captions = []
        for i in trange(num_images):
            try:
                image, caption = self.coco_dataset[i]
                image = image.to(device).float()
                wm_image = wm_encoder(signature.unsqueeze(0).to(device), image.unsqueeze(0).to(device))
                self.dataset.append((wm_image, signature))
                self.dataset.append((image.unsqueeze(0).to(device), caption))
                self.images.append(wm_image)
                self.images.append(image.unsqueeze(0).to(device))
                self.captions.append(signature)
                self.captions.append(caption)
                
            
            except Exception as e:
                print(e)
        self.images = torch.stack(self.images)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]

In [6]:
train_dataset = CocoCaptionMixedWMDataset(signature, train_dataset, DATASET_SIZE)
val_dataset = CocoCaptionMixedWMDataset(signature, val_dataset, int(DATASET_SIZE/2))

100%|██████████| 600/600 [00:08<00:00, 67.37it/s]
100%|██████████| 300/300 [00:08<00:00, 35.51it/s]


In [13]:

pred_captions = []
true_captions = []
tp = 1e-10
fp = 1e-10
fn = 1e-10
tn = 1e-10
for i in trange(len(val_dataset)):
    image, caption = val_dataset[i]
    closest_idx = torch.argmin(torch.sum(torch.sqrt((train_dataset.images - image) ** 2), dim=(1,2,3,4))).item()
    pred_caption = train_dataset.captions[closest_idx]
    if type(caption) is torch.Tensor:
        caption = "".join(map(str, caption.int().tolist()))
        if type(pred_caption) is torch.Tensor:
            pred_caption = "".join(map(str, pred_caption.int().tolist()))
            tp += 1
        else:
            fn += 1
    else:
        if type(pred_caption) is torch.Tensor:
            pred_caption = "".join(map(str, pred_caption.int().tolist()))
            fp += 1
        else:
            fn += 1
            
        
    pred_captions.append(pred_caption)
    true_captions.append(caption)

    
    

100%|██████████| 600/600 [00:32<00:00, 18.34it/s]


In [14]:
def compute_evaluation_metrics(generated_captions, reference_captions):
    """
    Compute METEOR and ROUGE scores for generated captions.

    Args:
        generated_captions (list of str): Captions generated by the model.
        reference_captions (list of list of str): Reference captions for each image.

    Returns:
        dict: Dictionary containing METEOR, BLEU and ROUGE scores.
    """
    # Load METEOR and ROUGE metrics
    # Compute METEOR score
    meteor_score = meteor.compute(predictions=generated_captions, references=reference_captions)

    # Compute ROUGE scores
    rouge_score = rouge.compute(predictions=generated_captions, references=reference_captions)

    bleu_score = bleu.compute(predictions=generated_captions, references=reference_captions)

    # Aggregate results
    results = {
        "meteor": meteor_score["meteor"],
        "bleu": bleu_score["bleu"],
        "rouge1": rouge_score["rouge1"],
        "rouge2": rouge_score["rouge2"],
        "rougeL": rouge_score["rougeL"],
    }

    return results

In [15]:
print(compute_evaluation_metrics(pred_captions, true_captions))
print("Precision: ", tp/(tp+fp))
print("Recall: ", tp/(tp+fn))
print("F-1: ", 2*tp/(2*tp+fp+fn))

{'meteor': 0.49009940085471404, 'bleu': 0.4857747833015032, 'rouge1': 0.4665265127817766, 'rouge2': 0.43441502253911024, 'rougeL': 0.45614368503642344}
Precision:  0.5493562231759445
Recall:  0.6564102564101763
F-1:  0.5981308411214494
