In [1]:
import os
from os.path import exists
import torch
from torch.nn.functional import log_softmax, pad
import time
from torch.optim.lr_scheduler import LambdaLR
from torchtext.data.functional import to_map_style_dataset
from torch.utils.data import DataLoader
from torchtext.vocab import build_vocab_from_iterator
import torchtext.datasets as datasets
import spacy
import GPUtil
from torch.utils.data.distributed import DistributedSampler
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

In [2]:
from make_model import make_model
from Processing.batch import Batch
from Processing.labelsmoothing import LabelSmoothing
from config import Config

In [3]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [4]:
class TrainState:
    """Track number of steps, examples, and tokens processed"""

    step: int = 0  # Steps in the current epoch
    accum_step: int = 0  # Number of gradient accumulation steps
    samples: int = 0  # total # of examples used
    tokens: int = 0  # total # of tokens processed

In [5]:
def run_epoch(
    data_iter,
    model,
    loss_compute,
    optimizer,
    scheduler,
    mode="train",
    accum_iter=1,
    train_state=TrainState(),
):
    """Train a single epoch"""
    start = time.time()
    total_tokens = 0
    total_loss = 0
    tokens = 0
    n_accum = 0
    for i, batch in enumerate(data_iter):
        out = model(batch.src, batch.tgt, batch.src_mask, batch.tgt_mask)
        loss, loss_node = loss_compute(out, batch.tgt_y, batch.ntokens)
        # loss_node = loss_node / accum_iter
        
        if mode == "train" or mode == "train+log":
            loss_node.backward()
            train_state.step += 1
            train_state.samples += batch.src.shape[0]
            train_state.tokens += batch.ntokens
            if i % accum_iter == 0:
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)
                n_accum += 1
                train_state.accum_step += 1
            scheduler.step()

        total_loss += loss
        total_tokens += batch.ntokens
        tokens += batch.ntokens
        
        if i % 40 == 0 and (mode == "train" or mode == "train+log"):
            lr = optimizer.param_groups[0]["lr"]
            elapsed = time.time() - start
            print(("Epoch Step: %6d | Accumulation Step: %3d | Loss: %6.2f | Tokens / Sec: %7.1f | Learning Rate: %6.1e")
                % (i, n_accum, loss / batch.ntokens, tokens / elapsed, lr))
            start = time.time()
            tokens = 0
        del loss
        del loss_node
    return (total_loss / total_tokens), train_state

In [6]:
def rate(step, model_size, factor, warmup):
    """
    we have to default the step to 1 for LambdaLR function
    to avoid zero raising to negative power.
    """
    if step == 0:
        step = 1
    return factor * (model_size ** (-0.5) * min(step ** (-0.5), step * warmup ** (-1.5)))

In [7]:
class SimpleLossCompute:
    "A simple loss compute and train function."

    def __init__(self, projection, criterion):
        self.proj = projection
        self.criterion = criterion

    def __call__(self, x, y, norm):
        x = log_softmax(self.proj(x), dim=-1)
        sloss = (self.criterion(x.contiguous().view(-1, x.size(-1)), y.contiguous().view(-1)) / norm)
        return sloss.data * norm, sloss

In [8]:
class DummyOptimizer(torch.optim.Optimizer):
    def __init__(self):
        self.param_groups = [{"lr": 0}]
        None

    def step(self):
        None

    def zero_grad(self, set_to_none=False):
        None
        
class DummyScheduler:
    def step(self):
        None

In [9]:
def load_tokenizers():
    try:
        spacy_de = spacy.load("de_core_news_sm")
    except IOError:
        os.system("python -m spacy download de_core_news_sm")
        spacy_de = spacy.load("de_core_news_sm")

    try:
        spacy_en = spacy.load("en_core_web_sm")
    except IOError:
        os.system("python -m spacy download en_core_web_sm")
        spacy_en = spacy.load("en_core_web_sm")

    return spacy_de, spacy_en

In [10]:
spacy_de, spacy_en = load_tokenizers()

In [11]:
def tokenize(text, tokenizer):
    return [tok.text for tok in tokenizer.tokenizer(text)]


def yield_tokens(data_iter, tokenizer, index):
    for from_to_tuple in data_iter:
        yield tokenizer(from_to_tuple[index])

In [12]:
def tokenize_de(text):
    return tokenize(text, spacy_de)

def tokenize_en(text):
    return tokenize(text, spacy_en)

In [13]:
def build_vocabulary():
    print("Building German Vocabulary ...")
    train, val, test = datasets.Multi30k(language_pair=("de", "en"))
    vocab_src = build_vocab_from_iterator(
        yield_tokens(train + val, tokenize_de, index=0),
        min_freq=2,
        specials=["<s>", "</s>", "<blank>", "<unk>"],
    )

    print("Building English Vocabulary ...")
    train, val, test = datasets.Multi30k(language_pair=("de", "en"))
    vocab_tgt = build_vocab_from_iterator(
        yield_tokens(train + val, tokenize_en, index=1),
        min_freq=2,
        specials=["<s>", "</s>", "<blank>", "<unk>"],
    )

    vocab_src.set_default_index(vocab_src["<unk>"])
    vocab_tgt.set_default_index(vocab_tgt["<unk>"])
    
    del train
    del val
    del test
    
    return vocab_src, vocab_tgt

def load_vocab():
    if not exists("vocab.pt"):
        vocab_src, vocab_tgt = build_vocabulary()
        torch.save((vocab_src, vocab_tgt), "vocab.pt")
    else:
        vocab_src, vocab_tgt = torch.load("vocab.pt")
    print("Finished.\nVocabulary sizes:")
    print(len(vocab_src))
    print(len(vocab_tgt))
    return vocab_src, vocab_tgt


In [14]:
vocab_src, vocab_tgt = load_vocab()

Finished.
Vocabulary sizes:
8185
6291


In [15]:
def collate_batch(
    batch,
    src_pipeline,
    tgt_pipeline,
    src_vocab,
    tgt_vocab,
    device='cuda',
    max_padding=128,
    pad_id=2,
):
    bs_id = torch.tensor([0], device=device)  # <s> token id
    eos_id = torch.tensor([1], device=device)  # </s> token id
    src_list, tgt_list = [], []
    for (_src, _tgt) in batch:
        processed_src = torch.cat(
            [
                bs_id,
                torch.tensor(
                    src_vocab(src_pipeline(_src)),
                    dtype=torch.int64,
                    device=device,
                ),
                eos_id,
            ],
            0,
        )
        processed_tgt = torch.cat(
            [
                bs_id,
                torch.tensor(
                    tgt_vocab(tgt_pipeline(_tgt)),
                    dtype=torch.int64,
                    device=device,
                ),
                eos_id,
            ],
            0,
        )
        src_list.append(
            # warning - overwrites values for negative values of padding - len
            pad(
                processed_src,
                (
                    0,
                    max_padding - len(processed_src),
                ),
                value=pad_id,
            )
        )
        tgt_list.append(
            pad(
                processed_tgt,
                (0, max_padding - len(processed_tgt)),
                value=pad_id,
            )
        )

    src = torch.stack(src_list,)
    tgt = torch.stack(tgt_list)
    return src, tgt

In [16]:
def create_dataloaders(
    vocab_src,
    vocab_tgt,
    batch_size=12000,
    max_padding=128,
    is_distributed=True,
    device='cuda'
):
    # def create_dataloaders(batch_size=12000):
    def collate_fn(batch):
        return collate_batch(
            batch,
            tokenize_de,
            tokenize_en,
            vocab_src,
            vocab_tgt,
            device,
            max_padding=max_padding,
            pad_id=vocab_src.get_stoi()["<blank>"],
        )

    train_iter, valid_iter, _ = datasets.Multi30k(language_pair=("de", "en"))

    train_iter_map = to_map_style_dataset(train_iter)  # DistributedSampler needs a dataset len()
    train_sampler = (DistributedSampler(train_iter_map) if is_distributed else None)
    valid_iter_map = to_map_style_dataset(valid_iter)
    valid_sampler = (DistributedSampler(valid_iter_map) if is_distributed else None)

    train_dataloader = DataLoader(
        train_iter_map,
        batch_size=batch_size,
        shuffle=(train_sampler is None),
        sampler=train_sampler,
        collate_fn=collate_fn,
        generator=torch.Generator('cpu')
    )
    valid_dataloader = DataLoader(
        valid_iter_map,
        batch_size=batch_size,
        shuffle=(valid_sampler is None),
        sampler=valid_sampler,
        collate_fn=collate_fn,
        generator=torch.Generator('cpu')
    )
    return train_dataloader, valid_dataloader

In [17]:
def train_worker(
    gpu,
    ngpus_per_node,
    vocab_src,
    vocab_tgt,
    config,
    is_distributed=False,
):
    print(f"Train worker process using cuda:{gpu} for training", flush=True)
    pad_idx = vocab_tgt["<blank>"]
    d_model = 512
    model = make_model(config)
    is_main_process = True
    
    # if is_distributed:
    #     dist.init_process_group("nccl", init_method="env://", rank=gpu, world_size=ngpus_per_node)
    #     model = DDP(model, device_ids=[gpu])
    #     module = model.module
    #     is_main_process = gpu == 0

    criterion = LabelSmoothing(size=len(vocab_tgt), padding_idx=pad_idx, smoothing=0.1)
    criterion.to(device)

    train_dataloader, valid_dataloader = create_dataloaders(
        vocab_src,
        vocab_tgt,
        batch_size=config.batch_size // ngpus_per_node,
        max_padding=config.max_padding,
        is_distributed=is_distributed,
        device='cuda'
    )

    optimizer = torch.optim.Adam(model.parameters(), lr=config.base_lr, betas=(0.9, 0.98), eps=1e-9)
    lr_scheduler = LambdaLR(
        optimizer=optimizer,
        lr_lambda=lambda step: rate(
            step, d_model, factor=1, warmup=config.warmup
        ),
    )
    train_state = TrainState()

    def gen_batch(dataloader):
        for b0, b1 in dataloader:
            yield Batch(b0.to(device), b1.to(device), pad_idx)

    for epoch in range(config.num_epochs):
        # if is_distributed:
        #     train_dataloader.sampler.set_epoch(epoch)
        #     valid_dataloader.sampler.set_epoch(epoch)

        model.train()
        print(f"Epoch {epoch} Training ====", flush=True)
        _, train_state = run_epoch(
            gen_batch(train_dataloader),
            model,
            SimpleLossCompute(model.proj, criterion),
            optimizer,
            lr_scheduler,
            mode="train+log",
            accum_iter=config.accum_iter,
            train_state=train_state,
        )

        GPUtil.showUtilization()
        if is_main_process:
            file_path = "%s%.2d.pt" % (config.file_prefix, epoch)
            torch.save(model.state_dict(), file_path)
        torch.cuda.empty_cache()

        print(f"Epoch {epoch} Validation ====", flush=True)
        model.eval()
        sloss, _ = run_epoch(
            gen_batch(valid_dataloader),
            model,
            SimpleLossCompute(model.proj, criterion),
            DummyOptimizer(),
            DummyScheduler(),
            mode="eval", 
        )
        print(sloss)
        torch.cuda.empty_cache()

    if is_main_process:
        file_path = "%sfinal.pt" % config.file_prefix
        torch.save(model.state_dict(), file_path)

# def train_distributed_model(vocab_src, vocab_tgt, config):
#     ngpus = torch.cuda.device_count()
#     os.environ["MASTER_ADDR"] = "localhost"
#     os.environ["MASTER_PORT"] = "12356"
#     print(f"Number of GPUs detected: {ngpus}")
#     print("Spawning training processes ...")
#     mp.spawn(train_worker, nprocs=ngpus, args=(ngpus, vocab_src, vocab_tgt, config, True))

def train_model(vocab_src, vocab_tgt, config):
    # if config.distributed:
    #     train_distributed_model(vocab_src, vocab_tgt, config)
    # else:
        train_worker(0, 1, vocab_src, vocab_tgt, config, False)

In [31]:
config = Config()
config.batch_size = 32
config.src_vocab = len(vocab_src)
config.tgt_vocab = len(vocab_tgt)
config.N = 6
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [18]:
def load_trained_model():
    model_path = "multi30k_model_final.pt"
    if not exists(model_path):
        train_model(vocab_src, vocab_tgt, config)
    model = make_model(config).to(device)
    model.load_state_dict(torch.load("multi30k_model_final.pt"))
    return model

In [19]:
model = load_trained_model()

Train worker process using cuda:0 for training
Epoch 0 Training ====
Epoch Step:      0 | Accumulation Step:   1 | Loss:   7.64 | Tokens / Sec:   272.9 | Learning Rate: 2.7e-07
Epoch Step:     40 | Accumulation Step:   5 | Loss:   7.54 | Tokens / Sec:   958.7 | Learning Rate: 1.1e-05
Epoch Step:     80 | Accumulation Step:   9 | Loss:   7.12 | Tokens / Sec:   961.2 | Learning Rate: 2.2e-05
Epoch Step:    120 | Accumulation Step:  13 | Loss:   6.74 | Tokens / Sec:   967.5 | Learning Rate: 3.3e-05
Epoch Step:    160 | Accumulation Step:  17 | Loss:   6.53 | Tokens / Sec:   959.7 | Learning Rate: 4.3e-05
Epoch Step:    200 | Accumulation Step:  21 | Loss:   6.30 | Tokens / Sec:   959.1 | Learning Rate: 5.4e-05
Epoch Step:    240 | Accumulation Step:  25 | Loss:   6.09 | Tokens / Sec:   967.6 | Learning Rate: 6.5e-05
Epoch Step:    280 | Accumulation Step:  29 | Loss:   6.09 | Tokens / Sec:   980.7 | Learning Rate: 7.6e-05
Epoch Step:    320 | Accumulation Step:  33 | Loss:   5.78 | Tokens

In [52]:
def average(model, models):
    "Average models into model"
    print(isinstance(model, torch.nn.Module))  # Should print True
    print(all(isinstance(m, torch.nn.Module) for m in models))  # Check all models
    for p_target, *p_sources in zip(model.parameters(), *[m.parameters() for m in models]):
        p_target.data.copy_(torch.sum(torch.stack(p_sources), dim=0) / len(p_sources))

In [58]:
model = make_model(config).to(device)
model.load_state_dict(torch.load("multi30k_model_final.pt"))
models = [make_model(config).to(device) for _ in range(3)]

# Load state dictionaries separately
for version in range(3):
    models[version].load_state_dict(torch.load(f"multi30k_model_0{7-version}.pt"))

average(model, models)
torch.save(model.state_dict(), "multi30k_model_final_avg.pt")

True
True


In [55]:
def check_outputs(
    valid_dataloader,
    model,
    vocab_src,
    vocab_tgt,
    n_examples=15,
    pad_idx=2,
    eos_string="</s>",
):
    results = [()] * n_examples
    for idx in range(n_examples):
        print("\nExample %d ========\n" % idx)
        b = next(iter(valid_dataloader))
        rb = Batch(b[0], b[1], pad_idx)
        model.generate(rb.src, rb.src_mask, 64, 0)[0]

        src_tokens = [
            vocab_src.get_itos()[x] for x in rb.src[0] if x != pad_idx
        ]
        tgt_tokens = [
            vocab_tgt.get_itos()[x] for x in rb.tgt[0] if x != pad_idx
        ]

        print(
            "Source Text (Input)        : "
            + " ".join(src_tokens).replace("\n", "")
        )
        print(
            "Target Text (Ground Truth) : "
            + " ".join(tgt_tokens).replace("\n", "")
        )
        model_out = model.generate(rb.src, rb.src_mask, 72, 0)[0]
        model_txt = (
            " ".join(
                [vocab_tgt.get_itos()[x] for x in model_out if x != pad_idx]
            ).split(eos_string, 1)[0]
            + eos_string
        )
        print("Model Output               : " + model_txt.replace("\n", ""))
        results[idx] = (rb, src_tokens, tgt_tokens, model_out, model_txt)
    return results

In [59]:
def run_model_example(config, n_examples=5):
    global vocab_src, vocab_tgt, spacy_de, spacy_en

    print("Preparing Data ...")
    _, valid_dataloader = create_dataloaders(
        vocab_src,
        vocab_tgt,
        batch_size=1,
        is_distributed=False,
        device='cuda'
    )

    print("Loading Trained Model ...")

    model = make_model(config)
    model.load_state_dict(torch.load("multi30k_model_final_avg.pt", map_location=torch.device("cpu")))

    print("Checking Model Outputs:")
    example_data = check_outputs(valid_dataloader, model, vocab_src, vocab_tgt, n_examples=n_examples)
    return model, example_data

In [60]:
config = Config()
config.batch_size = 32
config.src_vocab = len(vocab_src)
config.tgt_vocab = len(vocab_tgt)
config.N = 6
run_model_example(config)

Preparing Data ...
Loading Trained Model ...
Checking Model Outputs:


Source Text (Input)        : <s> Ein Mann an einem Strand baut eine Sandburg . </s>
Target Text (Ground Truth) : <s> A man on a beach building a sand castle . </s>
Model Output               : <s> A man on a beach building a sand castle . </s>


Source Text (Input)        : <s> Ein Mann mit beginnender Glatze , der eine rote Rettungsweste trägt , sitzt in einem kleinen Boot . </s>
Target Text (Ground Truth) : <s> A balding man wearing a red life jacket is sitting in a small boat . </s>
Model Output               : <s> A balding man wearing a red life jacket sitting in a small boat . </s>


Source Text (Input)        : <s> Ein Mann führt zwei kleine <unk> in einem Park spazieren . </s>
Target Text (Ground Truth) : <s> A man is leading two small <unk> on a walk at a park . </s>
Model Output               : <s> A man walks two small <unk> in a park . </s>


Source Text (Input)        : <s> Zwei Männer hinter einer krei

(EncoderDecoder(
   (encoder): Encoder(
     (layers): ModuleList(
       (0-5): 6 x EncoderLayer(
         (self_attn): MultiHeadedAttention(
           (w_q): Linear(in_features=512, out_features=512, bias=True)
           (w_k): Linear(in_features=512, out_features=512, bias=True)
           (w_v): Linear(in_features=512, out_features=512, bias=True)
           (proj): Linear(in_features=512, out_features=512, bias=True)
           (dropout): Dropout(p=0.1, inplace=False)
         )
         (feed_forward): FeedForward(
           (w_1): Linear(in_features=512, out_features=2048, bias=True)
           (w_2): Linear(in_features=2048, out_features=512, bias=True)
           (dropout): Dropout(p=0.1, inplace=False)
         )
         (sublayer): ModuleList(
           (0-1): 2 x SublayerConnection(
             (norm): LayerNorm()
             (dropout): Dropout(p=0.1, inplace=False)
           )
         )
       )
     )
     (norm): LayerNorm()
   )
   (decoder): Decoder(
     (lay