# Imports

In [36]:
import pandas as pd
import numpy as np

# PyTorch related imports
import torch
from torch.nn import *
from torch.nn import functional as F

# PyTorch Lightning related imports
import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

# metrics
from torchmetrics.functional import mean_absolute_error, mean_squared_error
from metrics.regression_metrics import t_kendalltau, t_pearson, t_spearman

# fnet architecture & RAdam
from fnet.model import FNet
from radam.radam import RAdam

# Custom datamodule
from datamodules.tm_datamodule import TextMiningDataModule

# Weights & Biases logging (currently not working, so disabled in the training cell)
import wandb
from pytorch_lightning.loggers import WandbLogger

# ModelCheckpoint currently fails (the tutorial it was taken from is outdated).
# The import is kept for future improvements.
from pytorch_lightning.callbacks import ModelCheckpoint



# Dataset & DataLoader

In [23]:
# Presumably the three text sides of each example (source sentence, human
# reference, system translation). Not referenced by any later cell shown here.
srt = ["source", "reference", "translation"]

# Language pairs whose data lives under corpus/<pair>/.
language_pairs = ["cs-en", "de-en", "en-fi", "en-zh", "ru-en", "zh-en"]

# One score table (read from corpus/<pair>/scores.csv) per language pair.
scores = dict(
    (lang_pair, pd.read_csv(f"corpus/{lang_pair}/scores.csv"))
    for lang_pair in language_pairs
)

In [25]:
# Load the pre-computed LASER embeddings (reference, source, translation) and
# the z-score column for one language pair, all as float32 tensors.
pair = "de-en"
embedding_ref = torch.from_numpy(np.load(f"corpus/{pair}/laser.reference_embeds.npy")).float()
embedding_src = torch.from_numpy(np.load(f"corpus/{pair}/laser.source_embeds.npy")).float()
embedding_hyp = torch.from_numpy(np.load(f"corpus/{pair}/laser.translation_embeds.npy")).float()
# Convert the whole pandas column at once via .to_numpy() instead of letting
# torch.tensor iterate the Series element by element; torch.tensor copies,
# so `score` does not share memory with the DataFrame.
score = torch.tensor(scores[pair]["z-score"].to_numpy(), dtype=torch.float32)

In [26]:
# Length of the z-score vector, i.e. the number of rows in this pair's scores.csv.
score.size()

torch.Size([21704])

# Loggers

In [7]:
# Stop training once val_loss (lower is better) has not improved for 5
# consecutive validation checks. Intended to be passed as a Trainer callback.
early_stop_callback = EarlyStopping(
    monitor="val_loss",
    mode="min",
    patience=5,
    verbose=False,
)

In [8]:
class TranslationPredictionLogger(Callback):
    """Log model predictions on a fixed set of samples to the experiment logger.

    Each logged item is captioned with the model prediction and the target
    score. ``trainer.logger.experiment`` must expose ``.log`` with W&B
    semantics (e.g. the run object of a ``WandbLogger``) — presumed, confirm
    against the logger actually used.

    Args:
        val_samples: ``(inputs, scores)`` pair of tensors. It is stored on the
            callback and moved to the module's device on every call.
        num_samples: maximum number of samples included in each log call.
    """

    def __init__(self, val_samples, num_samples=32):
        super().__init__()
        self.num_samples = num_samples
        self.val_imgs, self.val_score = val_samples

    def on_validation_batch_end(self, trainer, pl_module, *args, **kwargs):
        """Log predictions for the stored samples after a validation batch.

        Lightning calls this hook with extra positional arguments
        (``outputs, batch, batch_idx, dataloader_idx`` in PL 1.x). They are
        accepted through ``*args``/``**kwargs`` and ignored, so the hook can
        actually be invoked by the trainer.

        NOTE(review): this fires after every validation batch and logs the same
        samples each time; ``on_validation_epoch_end`` may be the intended hook.
        """
        val_imgs = self.val_imgs.to(device=pl_module.device)
        val_score = self.val_score.to(device=pl_module.device)

        # Inference only: no autograd graph is needed just to log predictions.
        with torch.no_grad():
            predictions = pl_module(val_imgs)

        trainer.logger.experiment.log(
            {
                "examples": [
                    wandb.Image(x, caption=f"Prediction:{p}, Score: {y}")
                    for x, p, y in zip(
                        val_imgs[: self.num_samples],
                        predictions[: self.num_samples],
                        val_score[: self.num_samples],
                    )
                ]
            }
        )


# Modules

## Convolutional

In [9]:
# Channel widths of the convolutional stack: 64 at the first stage, doubling
# at every following stage (64, 128, 256, 512, 1024).
FEATURES_1, FEATURES_2, FEATURES_3, FEATURES_4, FEATURES_5 = (64 << k for k in range(5))

class Model(pl.LightningModule):
    """Convolutional regressor (Conv2d layers with (3, 1) kernels) plus an MLP head.

    Produces one score per example squashed into (-1, 1) by ``tanh``. It is
    trained with mean-squared error and optimized with Adam.

    NOTE(review): a second class also named ``Model`` (the FNet variant) is
    defined in a later cell and shadows this one once that cell runs.

    Args:
        input_shape: recorded by ``save_hyperparameters()`` but not used to
            size any layer; all layer sizes below are hard-coded.
        learning_rate: Adam learning rate.
    """

    # Metric names produced by ``_shared_step``, in logging order. Pointwise
    # errors and correlations are logged with different step/epoch settings.
    _ERROR_METRICS = ("loss", "mse", "mae")
    _CORRELATION_METRICS = ("kendalltau", "pearson", "spearman")

    def __init__(self, input_shape, learning_rate=0.001):
        super().__init__()

        self.save_hyperparameters()
        self.learning_rate = learning_rate

        # All kernels are (3, 1) with padding (1, 0), so the convolution slides
        # along the height axis only. Integer strides apply to both axes; the
        # height shrinks by /4 at c3 and /2 at c4, c5 and c6.
        # The first layer expects 3 input channels (presumably the source,
        # reference and translation embeddings stacked — TODO confirm).
        self.c1 = Conv2d(3, FEATURES_1, (3, 1), 1, (1, 0))
        self.c2 = Conv2d(FEATURES_1, FEATURES_1, (3, 1), 1, (1, 0))
        self.c3 = Conv2d(FEATURES_1, FEATURES_2, (3, 1), 4, (1, 0))
        self.c4 = Conv2d(FEATURES_2, FEATURES_3, (3, 1), 2, (1, 0))
        self.c5 = Conv2d(FEATURES_3, FEATURES_4, (3, 1), 2, (1, 0))
        # c6 / bn6 are registered but not used by _forward_features (their call
        # is commented out there); kept so the parameter set is unchanged.
        self.c6 = Conv2d(FEATURES_4, FEATURES_5, (3, 1), 2, (1, 0))
        # NOTE(review): self.bn normalizes both the c1 and the c2 outputs, so
        # its running statistics mix two different layers' activations.
        self.bn = BatchNorm2d(FEATURES_1)
        self.bn3 = BatchNorm2d(FEATURES_2)
        self.bn4 = BatchNorm2d(FEATURES_3)
        self.bn5 = BatchNorm2d(FEATURES_4)
        self.bn6 = BatchNorm2d(FEATURES_5)
        # With c6 disabled, the flattened features are FEATURES_4 channels
        # times the remaining spatial positions. 64 positions corresponds to,
        # e.g., input height 1024 and width 1 (1024 -> 256 -> 128 -> 64); this
        # input size is assumed, TODO confirm. 64 * FEATURES_4 equals the
        # former 32 * FEATURES_5.
        self.fc1 = Linear(64 * FEATURES_4, 256)
        self.fc2 = Linear(256, 64)
        self.fc3 = Linear(64, 1)

    def _forward_features(self, x):
        """Apply the convolutional stack (c1..c5, each followed by ReLU then BatchNorm)."""
        x = self.bn(F.relu(self.c1(x)))
        # Residual variant, kept for experimentation:
        # x = self.bn(F.relu(self.c2(x) + x))
        x = self.bn(F.relu(self.c2(x)))
        x = self.bn3(F.relu(self.c3(x)))
        x = self.bn4(F.relu(self.c4(x)))
        x = self.bn5(F.relu(self.c5(x)))
        # x = self.bn6(F.relu(self.c6(x)))
        return x

    def forward(self, x):
        """Return a ``(batch, 1)`` tensor of predictions in (-1, 1)."""
        x = self._forward_features(x)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = torch.tanh(self.fc3(x))
        return x

    def _shared_step(self, batch):
        """Run the model on one ``(x, y)`` batch and compute every tracked metric.

        Returns a dict keyed by the names in ``_ERROR_METRICS`` and
        ``_CORRELATION_METRICS``. ``"loss"`` is the MSE that gets optimized.

        NOTE(review): ``F.mse_loss`` broadcasts if ``y``'s shape differs from
        the ``(batch, 1)`` predictions; this assumes the data module yields
        matching shapes — TODO confirm.
        NOTE(review): correlations are computed per batch. With
        ``on_epoch=True``, Lightning reduces them across batches (mean by
        default), which is not the same as a correlation over the whole split.
        """
        x, y = batch
        logits = self(x)
        return {
            "loss": F.mse_loss(logits, y),
            "mse": mean_squared_error(logits, y),
            "mae": mean_absolute_error(logits, y),
            "kendalltau": t_kendalltau(logits, y),
            "pearson": t_pearson(logits, y),
            "spearman": t_spearman(logits, y),
        }

    def _log_metrics(self, stage, metrics, error_kwargs, corr_kwargs):
        """Log every metric as ``"<stage>_<name>"`` with per-group ``self.log`` kwargs."""
        for name in self._ERROR_METRICS:
            self.log(f"{stage}_{name}", metrics[name], **error_kwargs)
        for name in self._CORRELATION_METRICS:
            self.log(f"{stage}_{name}", metrics[name], **corr_kwargs)

    def training_step(self, batch, batch_idx):
        metrics = self._shared_step(batch)
        self._log_metrics(
            "train",
            metrics,
            error_kwargs=dict(on_step=True, on_epoch=True, logger=True),
            corr_kwargs=dict(on_step=False, on_epoch=True, logger=True),
        )
        return metrics["loss"]

    def validation_step(self, batch, batch_idx):
        metrics = self._shared_step(batch)
        self._log_metrics(
            "val",
            metrics,
            error_kwargs=dict(prog_bar=True),
            corr_kwargs=dict(on_step=False, prog_bar=True),
        )
        return metrics["loss"]

    def test_step(self, batch, batch_idx):
        metrics = self._shared_step(batch)
        self._log_metrics(
            "test",
            metrics,
            error_kwargs=dict(on_step=True, prog_bar=True),
            corr_kwargs=dict(prog_bar=True),
        )
        return metrics["loss"]

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer


## Fourier Transformer (FNet) decoder

In [42]:
class Model(pl.LightningModule):
    """FNet module (``self.decoder``) followed by an MLP regression head.

    Produces one score per example squashed into (-1, 1) by ``tanh``. It is
    trained with mean-squared error and optimized with RAdam.

    NOTE(review): this redefines the ``Model`` name used by the convolutional
    cell; whichever cell ran last determines which class ``Model`` refers to.

    Args:
        input_shape: feature size passed to ``FNet`` and expected as the last
            dimension by the head's first ``Linear``.
        learning_rate: RAdam learning rate.
        num_layers: forwarded positionally to ``FNet`` (first argument).
        dropout: forwarded to ``FNet`` and used by the head's ``Dropout`` layers.
        dim_ff: forwarded positionally to ``FNet`` (third argument).
    """

    # Metric names produced by ``_shared_step``, in logging order. Pointwise
    # errors and correlations are logged with different step/epoch settings.
    _ERROR_METRICS = ("loss", "mse", "mae")
    _CORRELATION_METRICS = ("kendalltau", "pearson", "spearman")

    def __init__(
        self,
        input_shape,
        learning_rate=0.001,
        num_layers: int = 6,
        dropout: float = 0.1,
        dim_ff: int = 2048,
    ):
        super().__init__()

        self.save_hyperparameters()
        self.learning_rate = learning_rate
        self.decoder = FNet(num_layers, input_shape, dim_ff, dropout)

        # Regression head: input_shape -> input_shape // 2 -> input_shape // 8 -> 1
        # (e.g. 3072 -> 1536 -> 384 -> 1 for input_shape=3072).
        hidden_1 = input_shape // 2
        hidden_2 = input_shape // 8
        self.fc_bloc = Sequential(
            Linear(input_shape, hidden_1),
            GELU(),
            Dropout(dropout),
            Linear(hidden_1, hidden_2),
            GELU(),
            Dropout(dropout),
            Linear(hidden_2, 1),
        )

    def forward(self, x):
        """Apply FNet, the MLP head, and ``tanh``; the last dimension of the output is 1."""
        x = self.decoder(x)
        x = self.fc_bloc(x)
        x = torch.tanh(x)
        return x

    def _shared_step(self, batch):
        """Run the model on one ``(x, y)`` batch and compute every tracked metric.

        Returns a dict keyed by the names in ``_ERROR_METRICS`` and
        ``_CORRELATION_METRICS``. ``"loss"`` is the MSE that gets optimized.

        NOTE(review): ``F.mse_loss`` broadcasts if ``y``'s shape differs from
        the predictions; this assumes the data module yields matching shapes —
        TODO confirm.
        NOTE(review): correlations are computed per batch. With
        ``on_epoch=True``, Lightning reduces them across batches (mean by
        default), which is not the same as a correlation over the whole split.
        """
        x, y = batch
        logits = self(x)
        return {
            "loss": F.mse_loss(logits, y),
            "mse": mean_squared_error(logits, y),
            "mae": mean_absolute_error(logits, y),
            "kendalltau": t_kendalltau(logits, y),
            "pearson": t_pearson(logits, y),
            "spearman": t_spearman(logits, y),
        }

    def _log_metrics(self, stage, metrics, error_kwargs, corr_kwargs):
        """Log every metric as ``"<stage>_<name>"`` with per-group ``self.log`` kwargs."""
        for name in self._ERROR_METRICS:
            self.log(f"{stage}_{name}", metrics[name], **error_kwargs)
        for name in self._CORRELATION_METRICS:
            self.log(f"{stage}_{name}", metrics[name], **corr_kwargs)

    def training_step(self, batch, batch_idx):
        metrics = self._shared_step(batch)
        self._log_metrics(
            "train",
            metrics,
            error_kwargs=dict(on_step=True, on_epoch=True, logger=True),
            corr_kwargs=dict(on_step=False, on_epoch=True, logger=True),
        )
        return metrics["loss"]

    def validation_step(self, batch, batch_idx):
        metrics = self._shared_step(batch)
        self._log_metrics(
            "val",
            metrics,
            error_kwargs=dict(prog_bar=True),
            corr_kwargs=dict(on_step=False, prog_bar=True),
        )
        return metrics["loss"]

    def test_step(self, batch, batch_idx):
        metrics = self._shared_step(batch)
        self._log_metrics(
            "test",
            metrics,
            error_kwargs=dict(on_step=True, prog_bar=True),
            corr_kwargs=dict(prog_bar=True),
        )
        return metrics["loss"]

    def configure_optimizers(self):
        # RAdam from the local radam package (torch.optim.Adam was used previously).
        optimizer = RAdam(self.parameters(), lr=self.learning_rate)
        return optimizer


# Training + Predictions

In [43]:
# Data module for the de-en pair. The first argument is presumably the batch
# size, and 3 * 1024 presumably the input feature size (three 1024-d LASER
# embeddings concatenated) — TODO confirm against TextMiningDataModule.
dm = TextMiningDataModule(256, "de-en", 3 * 1024)
dm.setup()

In [44]:
# val_samples = next(iter(dm.train_dataloader()))
# val_imgs, val_score = val_samples
# val_imgs.shape, val_score.shape

In [45]:
# Build the model (whichever `Model` class was defined last) and a one-epoch
# trainer. W&B logging and early stopping are currently disabled.
model = Model(dm.dims, learning_rate=0.001, dim_ff=4096)
# wandb_logger = WandbLogger(project="wandb-lightning", job_type="train")

trainer = pl.Trainer(
    max_epochs=1,
    progress_bar_refresh_rate=1,
    # Use one GPU when available; fall back to CPU instead of failing on
    # machines without CUDA.
    gpus=1 if torch.cuda.is_available() else None,
    # logger=wandb_logger,
    # callbacks=[early_stop_callback],
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores


In [46]:
# Anomaly detection toggle, left disabled:
# torch.autograd.set_detect_anomaly(False)

# Train (and validate each epoch) using the data module's dataloaders.
trainer.fit(model, datamodule=dm)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type       | Params
---------------------------------------
0 | decoder | FNet       | 151 M 
1 | fc_bloc | Sequential | 5.3 M 
---------------------------------------
156 M     Trainable params
0         Non-trainable params
156 M     Total params
625.837   Total estimated model params size (MB)


Epoch 0:   1%|          | 1/86 [00:00<00:56,  1.52it/s, loss=0.581, v_num=2, val_loss=0.691, val_mse=0.691, val_mae=0.684, val_kendalltau=-.0528, val_pearson=-.0384, val_spearman=-.0789]

	addcmul_(Number value, Tensor tensor1, Tensor tensor2)
Consider using one of the following signatures instead:
	addcmul_(Tensor tensor1, Tensor tensor2, *, Number value) (Triggered internally at  /opt/conda/conda-bld/pytorch_1616554798336/work/torch/csrc/utils/python_arg_parser.cpp:1005.)
  exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)


Epoch 0: 100%|██████████| 86/86 [00:16<00:00,  5.07it/s, loss=0.742, v_num=2, val_loss=0.728, val_mse=0.728, val_mae=0.690, val_kendalltau=0.095, val_pearson=0.148, val_spearman=0.141]


# Evaluation (validation split; no test-split evaluation is run yet)

In [None]:
# Inspect the z-score target tensor loaded earlier for the selected pair.
score

In [47]:
# Re-run the validation loop after fit(). No model or datamodule is passed, so
# the trainer presumably reuses what it was fitted with. The reported metrics
# are val_* (validation split), not test-split metrics.
# NOTE(review): in PL 1.x validate() defaults to ckpt_path="best" — confirm
# which weights are evaluated.
trainer.validate()

[autoreload of radam.radam failed: Traceback (most recent call last):
  File "/home/fsx/miniconda3/envs/pl/lib/python3.9/site-packages/IPython/extensions/autoreload.py", line 245, in check
    superreload(m, reload, self.old_objects)
  File "/home/fsx/miniconda3/envs/pl/lib/python3.9/site-packages/IPython/extensions/autoreload.py", line 410, in superreload
    update_generic(old_obj, new_obj)
  File "/home/fsx/miniconda3/envs/pl/lib/python3.9/site-packages/IPython/extensions/autoreload.py", line 347, in update_generic
    update(a, b)
  File "/home/fsx/miniconda3/envs/pl/lib/python3.9/site-packages/IPython/extensions/autoreload.py", line 302, in update_class
    if update_generic(old_obj, new_obj): continue
  File "/home/fsx/miniconda3/envs/pl/lib/python3.9/site-packages/IPython/extensions/autoreload.py", line 347, in update_generic
    update(a, b)
  File "/home/fsx/miniconda3/envs/pl/lib/python3.9/site-packages/IPython/extensions/autoreload.py", line 266, in update_function
    setat

--------------------------------------------------------------------------------
DATALOADER:0 VALIDATE RESULTS
{'val_kendalltau': 0.09496324509382248,
 'val_loss': 0.7278435826301575,
 'val_mae': 0.6904030442237854,
 'val_mse': 0.7278435826301575,
 'val_pearson': 0.14763259887695312,
 'val_spearman': 0.14055135846138}
--------------------------------------------------------------------------------


[{'val_loss': 0.7278435826301575,
  'val_mse': 0.7278435826301575,
  'val_mae': 0.6904030442237854,
  'val_kendalltau': 0.09496324509382248,
  'val_pearson': 0.14763259887695312,
  'val_spearman': 0.14055135846138}]