In [None]:
import torch.nn as nn
from datasets import load_dataset, load_from_disk, Dataset
# from datasets import load_from_disk
from collections import namedtuple
from features import VectorsLoader
import torch
from archs import Sender, Receiver
import egg.core as core
import torch.nn.functional as F
import sacrebleu

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import torch
import torch.nn as nn

class Encoder(nn.Module):
    def __init__(self, vocab_size, emb_dim, hid_dim, num_layers=2, pad_id=None):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_dim, padding_idx=pad_id)
        self.rnn = nn.LSTM(emb_dim, hid_dim, num_layers=num_layers, batch_first=True)

    def forward(self, src):
        # src: [B, T]
        embedded = self.emb(src)               # [B, T, emb_dim]
        outputs, (h, c) = self.rnn(embedded)  # outputs ignored for vanilla seq2seq
        return h, c                            # [num_layers, B, hid_dim]

    
    
class Decoder(nn.Module):
    def __init__(self, vocab_size, emb_dim, hid_dim, num_layers=2, pad_id=None):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_dim, padding_idx=pad_id)
        self.rnn = nn.LSTM(emb_dim, hid_dim, num_layers=num_layers, batch_first=True)
        self.fc = nn.Linear(hid_dim, vocab_size)

    def forward(self, tgt, h, c):
        # tgt: [B, T] (with <sos> prepended)
        embedded = self.emb(tgt)               # [B, T, emb_dim]
        outputs, (h, c) = self.rnn(embedded, (h, c))
        logits = self.fc(outputs)              # [B, T, vocab_size]
        return logits, h, c


In [3]:
from transformers import AutoTokenizer
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
tgt_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
tgt_pad_id = tgt_tokenizer.pad_token_id

In [4]:
from torch.nn.utils.rnn import pad_sequence
PAD_ID = 70

def collate(batch):
    src = [torch.tensor(b['message_truncated'], dtype=torch.long) for b in batch]
    src = pad_sequence(
        src,
        batch_first=True,
        padding_value=PAD_ID
    )

    tgt = tgt_tokenizer(
        [b['captions'][0] for b in batch],
        padding=True,
        return_tensors="pt"
    )["input_ids"]

    return src, tgt

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [6]:
encoder = Encoder(
    vocab_size=70+1,  # +1 for PAD
    emb_dim=256,
    hid_dim=512,
    pad_id=PAD_ID
).to(device)

decoder = Decoder(
    vocab_size=len(tgt_tokenizer.vocab),
    emb_dim=256,
    hid_dim=512,
    pad_id=tgt_pad_id
).to(device)

In [7]:
dataset = load_from_disk("../../../datasets/coco_train_msg_captions")
val_dataset = load_from_disk("../../../datasets/coco_val_msg_captions")

print(type(dataset['message_truncated']))
print(type(dataset['message_truncated'][0][0]))

<class 'datasets.arrow_dataset.Column'>
<class 'int'>


In [8]:
batch_size = 512
loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate
)
val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=collate
)

In [9]:
encoder

Encoder(
  (emb): Embedding(71, 256, padding_idx=70)
  (rnn): LSTM(256, 512, num_layers=2, batch_first=True)
)

In [10]:
criterion = nn.CrossEntropyLoss(ignore_index=tgt_pad_id)
enc_opt = torch.optim.Adam(encoder.parameters(), lr=1e-2)
dec_opt = torch.optim.Adam(decoder.parameters(), lr=1e-2)

encoder.train()
decoder.train()

Decoder(
  (emb): Embedding(30522, 256, padding_idx=0)
  (rnn): LSTM(256, 512, num_layers=2, batch_first=True)
  (fc): Linear(in_features=512, out_features=30522, bias=True)
)

In [11]:
num_epochs = 10

In [12]:
import wandb
config = {
    "device": str(device),
    "pad_id": PAD_ID,
    "tgt_pad_id": tgt_pad_id,
    "enc_vocab_size": encoder.emb.num_embeddings,
    "dec_vocab_size": decoder.emb.num_embeddings,
    "emb_dim": encoder.emb.embedding_dim,
    "hid_dim": encoder.rnn.hidden_size,
    "enc_num_layers": encoder.rnn.num_layers,
    "dec_num_layers": decoder.rnn.num_layers,
    "encoder_pad_idx": encoder.emb.padding_idx,
    "decoder_pad_idx": decoder.emb.padding_idx,
    "batch_size": getattr(loader, "batch_size", None),
    "lr_enc": enc_opt.param_groups[0]["lr"],
    "lr_dec": dec_opt.param_groups[0]["lr"],
    "optim_enc": type(enc_opt).__name__,
    "optim_dec": type(dec_opt).__name__,
    "optim_enc_betas": enc_opt.param_groups[0]["betas"],
    "optim_dec_betas": dec_opt.param_groups[0]["betas"],
    "criterion": type(criterion).__name__,
    "dataset_len": len(dataset),
    "dataset_features": list(dataset.features.keys()),
    "num_epochs": num_epochs,
    # "sos_id": sos_id,
    # "eos_id": eos_id,
    "tgt_tokenizer": getattr(tgt_tokenizer, "name_or_path", str(tgt_tokenizer)),
}

# ensure the subsequent wandb.init call will merge this config into the run
_orig_wandb_init = wandb.init
def _wandb_init_with_config(*args, **kwargs):
    run = _orig_wandb_init(*args, **kwargs)
    try:
        wandb.config.update(config)
    except Exception:
        pass
    return run
wandb.init = _wandb_init_with_config

print("Prepared wandb config:", config)
# wandb.init(project=project, id=run_id, name=run_name, **kwargs)
wandb.init(
    project='EmComm-Caption-Translator',
    name="find_max_batch",
    config={
        "emb_dim": 256,
        "hid_dim": 512,
        # "batch_size": 32,
        "lr": 3e-3,
        "num_epochs": num_epochs,
    }
)


Prepared wandb config: {'device': 'cuda', 'pad_id': 70, 'tgt_pad_id': 0, 'enc_vocab_size': 71, 'dec_vocab_size': 30522, 'emb_dim': 256, 'hid_dim': 512, 'enc_num_layers': 2, 'dec_num_layers': 2, 'encoder_pad_idx': 70, 'decoder_pad_idx': 0, 'batch_size': 512, 'lr_enc': 0.01, 'lr_dec': 0.01, 'optim_enc': 'Adam', 'optim_dec': 'Adam', 'optim_enc_betas': (0.9, 0.999), 'optim_dec_betas': (0.9, 0.999), 'criterion': 'CrossEntropyLoss', 'dataset_len': 118287, 'dataset_features': ['coco_url', 'captions', 'image_id', 'features', 'message', 'message_truncated'], 'num_epochs': 10, 'tgt_tokenizer': 'bert-base-uncased'}


[34m[1mwandb[0m: Currently logged in as: [33meignatenko[0m ([33mnipg-elte[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [13]:
import torch 

In [14]:
@torch.no_grad()
def evaluate(encoder, decoder, loader, criterion, device):
    encoder.eval()
    decoder.eval()

    total_loss = 0.0
    for src, tgt in loader:
        src, tgt = src.to(device), tgt.to(device)

        h, c = encoder(src)
        logits, _, _ = decoder(tgt[:, :-1], h, c)

        loss = criterion(
            logits.reshape(-1, logits.size(-1)),
            tgt[:, 1:].reshape(-1)
        )
        total_loss += loss.item()

    return total_loss / len(loader)


In [15]:
patience = 10
best_val_loss = float("inf")
patience_ctr = 0


In [16]:
from torch.optim.lr_scheduler import ReduceLROnPlateau

sched_enc = ReduceLROnPlateau(enc_opt, mode="min", factor=0.9, patience=20)
sched_dec = ReduceLROnPlateau(dec_opt, mode="min", factor=0.9, patience=20)


In [18]:
def evaluate(encoder, decoder, loader, criterion, device):
    encoder.eval()
    decoder.eval()

    total_loss = 0.0
    refs, hyps = [], []

    with torch.no_grad():
        for src, tgt in loader:
            src, tgt = src.to(device), tgt.to(device)

            h, c = encoder(src)
            logits, _, _ = decoder(tgt[:, :-1], h, c)

            loss = criterion(
                logits.reshape(-1, logits.size(-1)),
                tgt[:, 1:].reshape(-1)
            )
            total_loss += loss.item()

            pred = logits.argmax(-1)
            hyps.extend(pred.tolist())
            refs.extend(tgt[:, 1:].tolist())

    bleu = sacrebleu.corpus_bleu(
        [" ".join(map(str, h)) for h in hyps],
        [[" ".join(map(str, r)) for r in refs]]
    ).score

    return total_loss / len(loader), bleu


In [19]:
from tqdm import tqdm

for epoch in range(num_epochs):
    encoder.train()
    decoder.train()

    total_loss = 0.0

    for src, tgt in tqdm(loader, desc=f"epoch {epoch}"):
        src, tgt = src.to(device), tgt.to(device)

        enc_opt.zero_grad()
        dec_opt.zero_grad()

        h, c = encoder(src)
        logits, _, _ = decoder(tgt[:, :-1], h, c)

        loss = criterion(
            logits.reshape(-1, logits.size(-1)),
            tgt[:, 1:].reshape(-1)
        )

        loss.backward()
        enc_opt.step()
        dec_opt.step()
        # sched_enc.step(lo
        # ss)


        total_loss += loss.item()
        wandb.log(
            {"train/batch_loss": loss.item(),
            "lr/encoder": enc_opt.param_groups[0]["lr"],
            "lr/decoder": dec_opt.param_groups[0]["lr"],
        })


    train_loss = total_loss / len(loader)
    # val_loss = evaluate(encoder, decoder, val_loader, criterion, device)
    val_loss, val_bleu = evaluate(
    encoder, decoder, val_loader, criterion, device
    )
    sched_enc.step(val_loss)
    sched_dec.step(val_loss)

    # wandb.log({
    #     "epoch": epoch,
    #     "train/loss": train_loss,
    #     "val/loss": val_loss,
    # })
    wandb.log({
    "epoch": epoch,
    "train/loss": train_loss,
    "val/loss": val_loss,
    "val/bleu": val_bleu,
    })

    print(
        f"epoch {epoch}: "
        f"train_loss={train_loss:.4f} | val_loss={val_loss:.4f}"
    )

    # -------- early stopping --------
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_ctr = 0

        # optional: save best model
        torch.save({
            "encoder": encoder.state_dict(),
            "decoder": decoder.state_dict(),
        }, "best_model.pt")

        wandb.log({"early_stop/best_val_loss": best_val_loss})
    else:
        patience_ctr += 1
        wandb.log({"early_stop/patience": patience_ctr})

        if patience_ctr >= patience:
            print(f"Early stopping at epoch {epoch}")
            break


epoch 0: 100%|██████████| 232/232 [02:18<00:00,  1.68it/s]


NameError: name 'sacrebleu' is not defined

In [None]:
src_msg = torch.tensor([67, 44,  5, 23, 59, 65, 60, 61, 19, 14, 42, 67, 17, 10, 68, 18, 20, 43,
          0])

In [None]:
src = [src_msg]
src_padded = pad_sequence(
    src,
    batch_first=True,
    padding_value=PAD_ID
).to(device)
src_padded

tensor([[67, 44,  5, 23, 59, 65, 60, 61, 19, 14, 42, 67, 17, 10, 68, 18, 20, 43,
          0]], device='cuda:0')

In [None]:
def greedy_decode(encoder, decoder, src, max_len=50, device="cpu"):
    """
    src: [1, T_src] tensor, already padded
    returns: list of token IDs
    """
    encoder.eval()
    decoder.eval()

    sos_id = tgt_tokenizer.cls_token_id
    eos_id = tgt_tokenizer.sep_token_id

    with torch.no_grad():
        h, c = encoder(src.to(device))

        # first input to decoder
        tgt_id = torch.tensor([[sos_id]], device=device)
        output_ids = []

        for _ in range(max_len):
            logits, h, c = decoder(tgt_id, h, c)           # [1, 1, vocab_size]
            next_id = logits[:, -1, :].argmax(dim=-1)      # [1]
            next_id_item = next_id.item()

            if next_id_item == eos_id:
                break

            output_ids.append(next_id_item)
            tgt_id = next_id.unsqueeze(0)                   # feed predicted token

    return output_ids

def decode_tokens(ids):
    return tgt_tokenizer.decode(ids, skip_special_tokens=True)


In [None]:
src_ind = [62, 42, 31, 63, 22, 60, 38, 13, 62, 56,  4, 37, 12, 35, 58, 59, 57, 65,
         30, 62, 26, 51, 17, 24,  6, 37, 63, 50, 29, 40, 25, 50, 39,  9, 33, 19,
         47, 11,  8,  0]
src_example = torch.tensor([src_ind], device=device)
translated_ids = greedy_decode(encoder, decoder, src_example, max_len=50, device=device)
translated_text = decode_tokens(translated_ids)

print("Translation:", translated_text)


Translation: a man riding a wave on top of a surfboard.
