In [66]:
from datasets import load_dataset, load_from_disk, Dataset
# from datasets import load_from_disk
from collections import namedtuple
from features import VectorsLoader
import torch
from archs import Sender, Receiver
import egg.core as core
import torch.nn.functional as F

In [67]:
ds_val = load_from_disk("../../../datasets/coco_val_features_resnet_152")#.select(range(100))
ds_val

Dataset({
    features: ['coco_url', 'captions', 'image_id', 'features'],
    num_rows: 5000
})

In [68]:
ds_val['coco_url']

Column(['http://images.cocodataset.org/val2017/000000397133.jpg', 'http://images.cocodataset.org/val2017/000000037777.jpg', 'http://images.cocodataset.org/val2017/000000252219.jpg', 'http://images.cocodataset.org/val2017/000000087038.jpg', 'http://images.cocodataset.org/val2017/000000174482.jpg'])

In [69]:
checkpoint_path = "/home/elena/emcomm/emcomm_captions/checkpoints/full_game/best_epoch_6.pt"

In [70]:
def loss(
    _sender_input, _message, _receiver_input, receiver_output, _labels, _aux_input
):
    acc = (receiver_output.argmax(dim=1) == _labels).detach().float()
    loss = F.cross_entropy(receiver_output, _labels, reduction="none")
    return loss, {"acc": acc}

def init_game(checkpoint_path: str):
    checkpoint = torch.load(checkpoint_path, weights_only=False)
    OptsNamedTuple = namedtuple('Opts', checkpoint['opts'].keys())
    opts = OptsNamedTuple(*checkpoint['opts'].values())

    data_loader = VectorsLoader(
        perceptual_dimensions=opts.perceptual_dimensions,
        n_distractors=opts.n_distractors,
        batch_size=opts.batch_size,
        train_samples=opts.train_samples,
        validation_samples=opts.validation_samples,
        test_samples=opts.test_samples,
        shuffle_train_data=opts.shuffle_train_data,
        dump_data_folder=opts.dump_data_folder,
        load_data_path=opts.load_data_path,
        seed=opts.data_seed,
    )
    print(f"Data loaded. Number of features: {data_loader.n_features}")
    sender_orig = Sender(n_features=data_loader.n_features, n_hidden=opts.sender_hidden)

    receiver = Receiver(
        n_features=data_loader.n_features, linear_units=opts.receiver_hidden
    )

    if opts.mode.lower() == "gs":
        sender = core.RnnSenderGS(
            sender_orig,
            opts.vocab_size,
            opts.sender_embedding,
            opts.sender_hidden,
            cell=opts.sender_cell,
            max_len=opts.max_len,
            temperature=opts.temperature,
        )

        receiver = core.RnnReceiverGS(
            receiver,
            opts.vocab_size,
            opts.receiver_embedding,
            opts.receiver_hidden,
            cell=opts.receiver_cell,
        )

        # game = core.SenderReceiverRnnGS(sender, receiver, loss)
    else:
        raise NotImplementedError(f"Unknown training mode, {opts.mode}")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    sender.to(device)

    return sender

In [71]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [72]:
sender = init_game(checkpoint_path)
sender

Data loaded. Number of features: 1000


RnnSenderGS(
  (agent): Sender(
    (fc1): Linear(in_features=1000, out_features=50, bias=True)
  )
  (hidden_to_output): Linear(in_features=50, out_features=70, bias=True)
  (embedding): Linear(in_features=70, out_features=10, bias=True)
  (cell): RNNCell(10, 50)
)

In [73]:
def get_message_batch(sender, input_batch):
    with torch.no_grad():
        messages_probs = sender(torch.tensor(input_batch['features']).to(device))
        messages = messages_probs.argmax(dim=-1)
        
    # return messages
    return {'message': messages.cpu(), 'probs': messages_probs.cpu()}

In [74]:
new_ds = ds_val.map(
    lambda example: {
        # "captions": example["captions"],
        "message": get_message_batch(sender, example)['message'],
    },
    batched=True,
    batch_size=64,
    # num_proc=24
)

Map:   0%|          | 0/5000 [00:00<?, ? examples/s]

Map: 100%|██████████| 5000/5000 [00:10<00:00, 477.22 examples/s]


In [75]:
def truncate_message(example):
    key = "message"
    msg = example[key]

    seq = msg
    idx = seq.index(0)
    seq = seq[:idx]

    # return {'message_truncated': seq}
    return seq

In [76]:
new_ds = new_ds.map(
    lambda example: {
        'message_truncated': truncate_message(example),
    }
)

new_ds

Map:   0%|          | 0/5000 [00:00<?, ? examples/s]

Map: 100%|██████████| 5000/5000 [00:00<00:00, 6189.01 examples/s]


Dataset({
    features: ['coco_url', 'captions', 'image_id', 'features', 'message', 'message_truncated'],
    num_rows: 5000
})

In [77]:
new_ds.save_to_disk("../../../datasets/coco_val_msg_captions")

Saving the dataset (0/1 shards):   0%|          | 0/5000 [00:00<?, ? examples/s]

Saving the dataset (1/1 shards): 100%|██████████| 5000/5000 [00:00<00:00, 10862.91 examples/s]


In [78]:
# new_ds['message_truncated'][0]

In [79]:
# sum([len(msg) for msg in new_ds['message_truncated']])/100
