In [None]:
!pip install -q tf-estimator-nightly==2.8.0.dev2021122109 earthengine-api==0.1.238 folium==0.2.1
!pip install -q torchtext==0.11.0 torchaudio==0.10.0 torchvision==0.11.1 torch==1.10
!pip install -q cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/colab/1.10/torch_xla-1.10-cp37-cp37m-linux_x86_64.whl
!pip install -q transformers pytorch_lightning datasets pyngrok

# !wget https://bin.equinox.io/c/4VmDzA7iaHb/ngrok-stable-linux-amd64.tgz
# !tar zxvf ngrok-stable-linux-amd64.tgz
!./ngrok authtoken 1y2MQMr0xLh05Dvbb0dABiNQpAY_3bqEfwwEtM7duDaqwrN93

/bin/bash: ./ngrok: No such file or directory


In [None]:
from transformers.models.bart.tokenization_bart_fast import BartTokenizerFast
from datasets import load_dataset
import torch.utils.data as data
import pandas as pd
import os
from tensorboard import program
from pyngrok import ngrok
from pytorch_lightning import LightningModule, LightningDataModule, Trainer
from transformers.models.bart.modeling_bart import BartForConditionalGeneration
import torch.nn.functional as F
import torch.optim as optim
import torch.nn as nn
import torch



In [None]:
def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
    """
    Build a warmup-then-linear-decay learning-rate schedule.

    The multiplier applied to the optimizer's initial lr rises linearly from 0 to 1 over
    the warmup phase and then falls linearly back to 0 at `num_training_steps`; it never
    becomes negative.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer whose learning rate is scheduled.
        num_warmup_steps (`int`):
            Length of the warmup phase, in steps.
        num_training_steps (`int`):
            Total number of training steps.
        last_epoch (`int`, *optional*, defaults to -1):
            Index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` driving the schedule described above.
    """
    # Both denominators are clamped to at least 1 so that degenerate settings never divide by zero.
    warmup_span = float(max(1, num_warmup_steps))
    decay_span = float(max(1, num_training_steps - num_warmup_steps))

    def lr_lambda(current_step: int):
        if current_step >= num_warmup_steps:
            # Decay phase: fraction of the post-warmup steps still remaining, floored at 0.
            steps_left = float(num_training_steps - current_step)
            return max(0.0, steps_left / decay_span)
        # Warmup phase: fraction of the warmup already completed.
        return float(current_step) / warmup_span

    return optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch)

In [None]:
def start_tensorboard(log_dir: str):
    """Launch TensorBoard on ``log_dir`` and expose it through an ngrok tunnel.

    The local TensorBoard URL and the object returned by ``ngrok.connect`` are printed.

    Args:
        log_dir: Directory passed to TensorBoard's ``--logdir`` option.

    Raises:
        ValueError: If the URL returned by TensorBoard carries no usable port.
    """
    # Function-scope import: only the standard library is needed and the file's
    # import cell stays untouched.
    from urllib.parse import urlparse

    tb = program.TensorBoard()
    tb.configure(argv=[None, '--logdir', log_dir])
    url = tb.launch()
    print(f"TensorBoard listening on {url}")

    # Parse the port instead of slicing the string: the former `[:-1]` slice silently
    # dropped the last digit whenever the URL did not end with a trailing slash.
    port = urlparse(url).port
    if port is None:
        raise ValueError(f"Could not determine the TensorBoard port from {url!r}")
    print(ngrok.connect(port))

In [None]:
class Dataset(data.Dataset):
    """Map-style dataset over one split stored as ``Data/{split}.csv``.

    Each item is ``((input_ids, attention_mask), (decoder_input_ids, decoder_attention_mask))``
    where the first pair encodes the document and the second pair encodes the summary.
    Every tensor is padded/truncated to the tokenizer's maximum length and has its
    leading batch dimension removed.
    """

    # Shared tokenizer, loaded once when the class body is executed; other cells read it
    # through ``Dataset.tokenizer``.
    tokenizer = BartTokenizerFast.from_pretrained("sshleifer/distilbart-xsum-12-1")

    def __init__(self, split: str = "train"):
        """Load ``Data/{split}.csv`` into memory, discarding the ``id`` column."""
        self.data = pd.read_csv(f"Data/{split}.csv").drop("id", axis=1)
        # Rows are unpacked positionally in __getitem__.
        # NOTE(review): assumes the remaining columns are (document, summary) in that order — confirm.
        self.data = self.data.to_numpy()

    def _encode(self, text):
        """Tokenize ``text`` and return ``(input_ids, attention_mask)`` without the batch dim."""
        encoded = self.tokenizer(text, padding="max_length", truncation=True, return_tensors="pt")
        # Access the fields by name: unpacking ``.values()`` positionally breaks as soon
        # as the tokenizer returns an additional field or a different field order.
        return encoded["input_ids"].squeeze(0), encoded["attention_mask"].squeeze(0)

    def __getitem__(self, idx):
        """Return the encoded (document, summary) pair stored at row ``idx``."""
        document, summary = self.data[idx]
        return self._encode(document), self._encode(summary)

    def __len__(self):
        """Number of rows in the loaded split."""
        return len(self.data)

In [None]:
class DataModule(LightningDataModule):
    """Lightning data module that materialises the XSum splits as CSV files under
    ``Data/`` and serves them through ``Dataset``.

    Args:
        batch_size: Batch size used by every dataloader.
    """

    def __init__(self, batch_size: int) -> None:
        super().__init__()
        self.batch_size = batch_size

        # Filled in lazily by ``setup``.
        self.train_dataset = None
        self.test_dataset = None
        self.val_dataset = None

    def prepare_data(self) -> None:
        """Download XSum and write one cleaned CSV per split, skipping splits already on disk.

        The check is done per file rather than on the ``Data`` directory: with the
        directory-only check, an interrupted first run left a ``Data`` folder with
        missing files behind and no later run ever repaired it.
        """
        splits = ("train", "test", "validation")
        missing = [split for split in splits if not os.path.exists(f"Data/{split}.csv")]
        if not missing:
            return

        os.makedirs("Data", exist_ok=True)
        datasets = load_dataset("xsum", name="bart-base")

        for split in missing:
            final_path = f"Data/{split}.csv"
            # Work on a temporary file and rename at the end, so a crash can never
            # leave a truncated CSV under the final name.
            partial_path = f"{final_path}.partial"
            datasets[split].to_csv(partial_path, index=False)
            # The CSV round trip is deliberate: pandas reads empty fields back as NaN
            # by default, so dropna() also removes rows whose text is empty.
            pd.read_csv(partial_path).dropna().to_csv(partial_path, index=False)
            os.replace(partial_path, final_path)

    def setup(self, stage: str = None) -> None:
        """Instantiate the datasets required by ``stage`` (all of them when it is ``None``).

        Datasets that already exist are kept, so calling this repeatedly is cheap.
        """
        if stage in ("fit", None) and self.train_dataset is None:
            self.train_dataset = Dataset(split="train")

        # The validation split is needed while fitting and for a stand-alone validation run.
        if stage in ("fit", "validate", None) and self.val_dataset is None:
            self.val_dataset = Dataset(split="validation")

        if stage in ("test", None) and self.test_dataset is None:
            self.test_dataset = Dataset(split="test")

    def train_dataloader(self):
        """Dataloader over the training split."""
        return data.DataLoader(self.train_dataset, batch_size=self.batch_size)

    def val_dataloader(self):
        """Dataloader over the validation split."""
        return data.DataLoader(self.val_dataset, batch_size=self.batch_size)

    def test_dataloader(self):
        """Dataloader over the test split."""
        return data.DataLoader(self.test_dataset, batch_size=self.batch_size)

In [None]:
class Model(LightningModule):
    """Lightning wrapper around the pretrained ``sshleifer/distilbart-xsum-12-1`` BART model."""

    def __init__(self):
        super(Model, self).__init__()
        self.model = BartForConditionalGeneration.from_pretrained("sshleifer/distilbart-xsum-12-1")

    @staticmethod
    def shift_right(tensor: torch.Tensor, start_token: int):
        """Shift a ``(batch_size, seq_len)`` tensor one position to the right.

        The last column is dropped and ``start_token`` is written into the first column.
        A new tensor is returned; ``tensor`` itself is not modified.
        """
        shifted = torch.zeros_like(tensor)

        shifted[:, 1:] = tensor[:, :-1].clone()
        shifted[:, 0] = start_token
        return shifted

    def forward(self, batch: tuple):
        """Run teacher-forced decoding and return the raw logits.

        Args:
            batch: ``((input_ids, attention_mask), (decoder_ids, decoder_attention_mask))``
                as produced by ``Dataset``.

        Returns:
            Logits of shape ``(batch_size, seq_len, vocab_size)``; no activation is applied.
        """
        (input_ids, attention_mask), (decoder_ids, decoder_attention_mask) = batch
        # decoder_ids.shape: (batch_size, seq_len)

        # BART decoders are started with the checkpoint's ``decoder_start_token_id``,
        # which is not necessarily the tokenizer's BOS token. Using the configured value
        # keeps the teacher-forced inputs in line with what generation feeds the decoder.
        # Fall back to the former BOS behaviour only when the config defines no start token.
        start_token = self.model.config.decoder_start_token_id
        if start_token is None:
            start_token = Dataset.tokenizer.bos_token_id
        decoder_inps = self.shift_right(decoder_ids, start_token)

        # The mask moves together with the ids; the newly inserted start token is always attended to.
        decoder_attention_mask = self.shift_right(decoder_attention_mask, 1)

        # Keyword arguments protect against positional-order changes in the model's forward().
        logits: torch.Tensor = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_inps,
            decoder_attention_mask=decoder_attention_mask,
        ).logits
        # logits.shape: (batch_size, seq_len, vocab_size)
        return logits

In [None]:
class TrainModel(Model):
    """``Model`` extended with the loss, the accuracy metric, the optimizer and the Lightning step hooks.

    Args:
        learning_rate: Peak learning rate handed to AdamW.
        ultimate_batch_size: Number of samples consumed per optimizer step across all devices.
        epochs: Number of training epochs, used to size the learning-rate schedule.
        label_smoothing: Forwarded to ``nn.CrossEntropyLoss``.
        ignore_index: Target id excluded from the loss and from the accuracy.
            ``None`` selects PyTorch's default of ``-100``.
    """

    def __init__(self, learning_rate: float, ultimate_batch_size: int, epochs: int, label_smoothing=0.0, ignore_index=None):
        super(TrainModel, self).__init__()
        self.learning_rate = learning_rate

        # nn.CrossEntropyLoss needs an integer here; forwarding None only fails later,
        # when the loss is first evaluated, so map it to the documented default instead.
        if ignore_index is None:
            ignore_index = -100
        self.criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing, ignore_index=ignore_index)

        # Samples per epoch; the original note gives this as len(train_dataset) + len(val_dataset).
        # NOTE(review): the scheduler below uses interval="step"; if it only advances on
        # training steps, counting validation samples overestimates the schedule length — confirm.
        samples_per_epoch = 215344
        self.num_training_steps = (samples_per_epoch // ultimate_batch_size) * epochs
        # The first 10% of the schedule is warmup.
        self.num_warmup_steps = int(self.num_training_steps * 0.1)

    def configure_optimizers(self):
        """Return AdamW together with a per-step warmup/linear-decay scheduler."""
        optimizer = optim.AdamW(self.parameters(), lr=self.learning_rate)
        scheduler = get_linear_schedule_with_warmup(optimizer, self.num_warmup_steps, self.num_training_steps)
        scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}
        return [optimizer], [scheduler]

    @staticmethod
    @torch.no_grad()
    def calculate_accuracy(logits: torch.Tensor, target: torch.Tensor, ignore_index=None):
        """Token-level accuracy of the arg-max predictions.

        Args:
            logits: ``(batch_size, seq_len, vocab_size)`` scores.
            target: ``(batch_size, seq_len)`` reference ids.
            ignore_index: When given, positions whose target equals this id are left out
                of both the numerator and the denominator. ``None`` scores every position.

        Returns:
            Scalar tensor holding the fraction of correctly predicted positions.
        """
        # The arg-max of the logits equals the arg-max of their softmax, so the softmax is skipped.
        predictions = logits.argmax(dim=-1)
        # predictions.shape: (batch_size, seq_len)
        correct = predictions == target
        if ignore_index is None:
            return correct.sum() / correct.numel()

        scored = target != ignore_index
        # clamp guards against a batch in which every position is ignored.
        return (correct & scored).sum() / scored.sum().clamp(min=1)

    def forward_step(self, batch: tuple):
        """Compute ``(loss, accuracy)`` for one batch; the accuracy is a Python float."""
        _, (decoder_ids, _) = batch
        # decoder_ids.shape: (batch_size, seq_len)

        logits = self.forward(batch)
        # logits.shape: (batch_size, seq_len, vocab_size)
        loss: torch.Tensor = self.criterion(logits.reshape(-1, logits.size(-1)), decoder_ids.reshape(-1))
        # Score the same positions the loss uses; otherwise ignored (padding) positions
        # dominate the metric.
        accu = self.calculate_accuracy(logits, decoder_ids, ignore_index=self.criterion.ignore_index)
        return loss, accu.item()

    def training_step(self, batch: tuple, batch_idx: int):
        """Log the learning rate, loss and accuracy, and return the loss for backprop."""
        loss, accu = self.forward_step(batch)
        self.log("lr", self.lr_schedulers().get_last_lr()[0], prog_bar=True)
        self.log("loss", loss.item())
        self.log("accu", accu, prog_bar=True)
        return loss

    def validation_step(self, batch: tuple, batch_idx: int):
        """Log validation loss and accuracy."""
        loss, accu = self.forward_step(batch)
        self.log("val_loss", loss.item(), prog_bar=True)
        self.log("val_accu", accu, prog_bar=True)
        return loss

    def test_step(self, batch: tuple, batch_idx: int):
        """Log test loss and accuracy."""
        loss, accu = self.forward_step(batch)
        self.log("test_loss", loss.item(), prog_bar=True)
        self.log("test_accu", accu, prog_bar=True)
        return loss

In [None]:
# Run configuration.
epochs = 1
batch_size = 1  # handed to the data module; one optimizer step covers batch_size * tpu_cores samples
learning_rate = 1e-6
tpu_cores = 8


trainer = Trainer(max_epochs=epochs, tpu_cores=tpu_cores)
model = TrainModel(
    learning_rate=learning_rate,
    ultimate_batch_size=batch_size * tpu_cores,
    epochs=epochs,
    label_smoothing=0.1,
    # Padding positions are excluded from the loss.
    ignore_index=Dataset.tokenizer.pad_token_id,
)
datamodule = DataModule(batch_size)

GPU available: False, used: False
TPU available: True, using: 8 TPU cores
IPU available: False, using: 0 IPUs


In [None]:
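# Training is intentionally disabled here; the cells below only inspect a single batch.
# Uncomment the next line to fine-tune the model:
# trainer.fit(model, datamodule)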

In [None]:
# trainer.fit() is commented out above, so nothing has created the CSV files yet on a
# fresh kernel; prepare_data() skips the download when the data is already on disk.
datamodule.prepare_data()
datamodule.setup()
# Take the first training batch for manual inspection.
batch = next(iter(datamodule.train_dataloader()))

In [None]:
# Inspection only: skip autograd so no graph is kept alive for the full-vocabulary logits.
with torch.no_grad():
    out = model(batch)

In [None]:
# Shape of the logits; forward() documents it as (batch_size, seq_len, vocab_size).
out.shape

torch.Size([1, 1024, 50264])

In [None]:
# Token ids of the source document: first element of the (input_ids, attention_mask) pair.
batch[0][0]

tensor([[  0, 133, 455,  ...,   1,   1,   1]])

In [None]:
# Inverse vocabulary: token id -> token string.
my_vocab = {token_id: token for token, token_id in Dataset.tokenizer.vocab.items()}

In [None]:
# Quick check that the "Ġ" marker can be stripped from both ends of a token.
'ĠworstĠ'.strip("Ġ")

'worst'

In [None]:
# Rough, manual detokenization of the source document: look every id up in the inverse
# vocabulary, strip the "Ġ" and "Ċ" markers, and join the pieces with spaces.
' '.join(
    my_vocab[token_id].strip("Ġ").strip("Ċ")
    for token_id in batch[0][0].squeeze(0).tolist()
)



In [None]:
# Shape of the reference-summary token ids: first element of the decoder pair.
batch[1][0].shape

torch.Size([1, 1024])

In [None]:
# Greedy prediction per position and the reference ids, both with the batch dimension squeezed away.
pred_sum = torch.argmax(out, dim=-1).squeeze(0)
true_sum = torch.squeeze(batch[1][0], 0)

In [None]:
# Count matching positions, looking only at reference ids different from 1
# (the value that fills the tail of the padded tensors shown above).
torch.eq(pred_sum, true_sum)[true_sum != 1].sum()

tensor(12)

In [None]:
# Shape of the predictions restricted to positions whose reference id is not 1.
pred_sum[true_sum != 1].shape

torch.Size([26])

In [None]:
# Reference summary token ids, as assigned from batch[1][0].squeeze(0).
true_sum

tensor([    0, 40827,    12,  ...,     1,     1,     1])