# Imports

In [85]:
import pandas as pd
import numpy as np
import os
from typing import List, Union, Dict
import torch
from torch import nn, tensor
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split, TensorDataset
import re

# pytorch related imports
import torch
from torch.nn import *
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split

# lightning related imports
import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint

# metrics
from torchmetrics.functional import mean_absolute_error, mean_squared_error
from metrics.regression_metrics import t_kendalltau, t_pearson, t_spearman

# sklearn related imports
from sklearn.metrics import precision_recall_curve

from fnet import *

import wandb


# Dataset & DataLoader

In [2]:
# Field names; not referenced again in the visible part of the notebook.
srt = ["source", "reference", "translation"]

# Language pairs that have a folder under corpus/ (see the read below).
language_pairs = ["cs-en", "de-en", "en-fi", "en-zh", "ru-en", "zh-en"]

# One scores table per language pair, keyed by the pair name.
scores = dict(
    (pair, pd.read_csv(f"corpus/{pair}/scores.csv")) for pair in language_pairs
)

In [3]:
# Language pair inspected in the exploratory cells that follow.
pair = "de-en"

# Embedding matrices stored as .npy files, converted to float32 tensors.
embedding_ref = torch.from_numpy(
    np.load(f"corpus/{pair}/laser.reference_embeds.npy")
).to(torch.float32)
embedding_src = torch.from_numpy(
    np.load(f"corpus/{pair}/laser.source_embeds.npy")
).to(torch.float32)
embedding_hyp = torch.from_numpy(
    np.load(f"corpus/{pair}/laser.translation_embeds.npy")
).to(torch.float32)

# "z-score" column of the pair's scores table as a float32 column vector.
score = torch.tensor(scores[pair]["z-score"]).unsqueeze(1).to(torch.float32)

In [4]:
# Display the shape of the reference embedding matrix.
embedding_ref.size()

torch.Size([21704, 1024])

In [5]:
# View every embedding row as a single-channel 32x32 "image": (N, 1, 32, 32).
# All three tensors are reshaped with the row count of the reference matrix.
img_ref = embedding_ref.reshape(embedding_ref.shape[0], 32, 32)[:, None]
img_src = embedding_src.reshape(embedding_ref.shape[0], 32, 32)[:, None]
img_hyp = embedding_hyp.reshape(embedding_ref.shape[0], 32, 32)[:, None]
img_ref.shape

torch.Size([21704, 1, 32, 32])

In [6]:
# Stack source / reference / translation along the channel axis (dim 1).
imgs = torch.cat((img_src, img_ref, img_hyp), 1)
print(imgs.shape, score.shape, sep="\n")

torch.Size([21704, 3, 32, 32])
torch.Size([21704, 1])


In [50]:
# Height and width each embedding row is reshaped to inside TextMiningDataModule.
EMBEDDING_SHAPE1, EMBEDDING_SHAPE2 = 1024, 1


class TextMiningDataModule(pl.LightningDataModule):
    """Lightning data module serving embedding stacks for one language pair.

    Each sample is a float tensor of shape
    ``(3, EMBEDDING_SHAPE1, EMBEDDING_SHAPE2)`` whose channels are, in order,
    the source, reference and translation embeddings, paired with a ``(1,)``
    target.  Training targets come from the ``"z-score"`` column of the
    module-level ``scores`` dict; test targets are all zeros.

    Args:
        batch_size: Batch size used by every dataloader.
        pair: Language pair folder name, e.g. ``"de-en"``.
        dims: Forwarded unchanged to ``LightningDataModule.__init__``.
        num_workers: Worker processes per dataloader (default 12, the value
            that used to be hard-coded).
        seed: Seed for the train/validation split so that repeated ``setup``
            calls yield the same partition.  Pass ``None`` to fall back to the
            global RNG.
    """

    def __init__(
        self, batch_size, pair, dims, num_workers=12, seed=42,
    ):
        super().__init__(dims=dims,)
        self.batch_size = batch_size
        self.pair = pair
        self.num_workers = num_workers
        self.seed = seed

    def prepare_data(self):
        """Nothing to download or preprocess; the .npy files are read in ``setup``."""
        pass

    def _load_images(self, root):
        """Load the three embedding files under ``root`` and stack them as channels.

        Args:
            root: Top-level folder, ``"corpus"`` or ``"testset"``.

        Returns:
            Tuple ``(images, size)`` where ``images`` has shape
            ``(size, 3, EMBEDDING_SHAPE1, EMBEDDING_SHAPE2)`` and ``size`` is the
            row count of the reference embedding matrix.
        """
        embeddings = {
            name: torch.from_numpy(
                np.load(f"{root}/{self.pair}/laser.{name}_embeds.npy")
            ).float()
            for name in ("source", "reference", "translation")
        }
        # The reference matrix defines the sample count used for every reshape.
        size = embeddings["reference"].shape[0]
        channels = [
            torch.reshape(
                embeddings[name], (size, EMBEDDING_SHAPE1, EMBEDDING_SHAPE2)
            ).unsqueeze(1)
            for name in ("source", "reference", "translation")
        ]
        return torch.cat(channels, dim=1), size

    def setup(self, stage=None):
        """Build the datasets required for ``stage``.

        ``"fit"`` (or ``None``) creates ``ds_train`` / ``ds_val`` as a 90/10
        split of the corpus; ``"test"`` (or ``None``) creates ``ds_test``.
        """
        if stage == "fit" or stage is None:
            train_imgs, train_size = self._load_images("corpus")
            train_target = tensor(scores[self.pair]["z-score"]).unsqueeze(1).float()

            ds_full = TensorDataset(train_imgs, train_target)
            n_train = int(train_size * (0.9))
            split_kwargs = {}
            if self.seed is not None:
                # A dedicated generator keeps the split identical across calls,
                # so validation rows never migrate into the training subset.
                split_kwargs["generator"] = torch.Generator().manual_seed(self.seed)
            self.ds_train, self.ds_val = random_split(
                ds_full, [n_train, train_size - n_train], **split_kwargs
            )
        if stage == "test" or stage is None:
            test_imgs, test_size = self._load_images("testset")
            # Placeholder targets: zeros stand in for the test scores.
            self.ds_test = TensorDataset(test_imgs, torch.zeros((test_size, 1)))

    def train_dataloader(
        self,
    ) -> Union[DataLoader, List[DataLoader], Dict[str, DataLoader]]:
        """Shuffled loader over the training split."""
        return DataLoader(
            self.ds_train,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
        )

    def val_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
        """Loader over the validation split (no shuffling)."""
        return DataLoader(
            self.ds_val, batch_size=self.batch_size, num_workers=self.num_workers
        )

    def test_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
        """Loader over the test set (no shuffling)."""
        return DataLoader(
            self.ds_test, batch_size=self.batch_size, num_workers=self.num_workers
        )



# Loggers

In [51]:
# Early-stopping callback watching `val_loss` (lower is better) with patience 5.
# It is only constructed here; the Trainer cell keeps it commented out.
early_stop_callback = EarlyStopping(
    mode="min",
    monitor="val_loss",
    patience=5,
    verbose=False,
)

In [52]:
class TranslationPredictionLogger(Callback):
    """Callback that logs model predictions for a fixed batch of samples.

    Args:
        val_samples: ``(images, scores)`` pair of tensors kept for the whole run.
        num_samples: Maximum number of samples sent to the logger per call.
    """

    def __init__(self, val_samples, num_samples=32):
        super().__init__()
        self.num_samples = num_samples
        self.val_imgs, self.val_score = val_samples

    def on_validation_batch_end(
        self,
        trainer,
        pl_module,
        outputs=None,
        batch=None,
        batch_idx=0,
        dataloader_idx=0,
    ):
        """Log predictions for the stored samples once per validation run.

        Lightning invokes this hook with ``outputs``, ``batch``, ``batch_idx``
        and ``dataloader_idx`` in addition to ``trainer`` and ``pl_module``; they
        are accepted (with defaults) so the call does not raise ``TypeError``.
        Only the first batch of the first dataloader triggers logging, so the
        same images are not uploaded after every validation batch.
        """
        if batch_idx != 0 or dataloader_idx != 0:
            return

        # Move the stored samples to whichever device the module lives on.
        val_imgs = self.val_imgs.to(device=pl_module.device)
        val_score = self.val_score.to(device=pl_module.device)

        predictions = pl_module(val_imgs)
        trainer.logger.experiment.log(
            {
                "examples": [
                    wandb.Image(x, caption=f"Prediction:{p}, Score: {y}")
                    for x, p, y in zip(
                        val_imgs[: self.num_samples],
                        predictions[: self.num_samples],
                        val_score[: self.num_samples],
                    )
                ]
            }
        )


# Modules

## Convolutional

In [66]:
# Channel widths used by the model's layers; each one doubles the previous.
FEATURES_1 = 64
FEATURES_2 = 2 * FEATURES_1
FEATURES_3 = 2 * FEATURES_2
FEATURES_4 = 2 * FEATURES_3
FEATURES_5 = 2 * FEATURES_4

class Model(pl.LightningModule):
    """Convolutional regressor mapping a 3-channel embedding stack to one score.

    Five (3, 1)-kernel convolutions are followed by three fully connected
    layers and a final ``tanh``.

    Args:
        input_shape: Recorded through ``save_hyperparameters``; the layer sizes
            below do not depend on it.
        learning_rate: Learning rate handed to the Adam optimizer.
    """

    def __init__(self, input_shape, learning_rate=0.001):
        super().__init__()

        self.save_hyperparameters()
        self.learning_rate = learning_rate

        # Conv2d positional arguments: in_channels, out_channels, kernel_size,
        # stride, padding.  Every kernel is (3, 1) with (1, 0) padding, so only
        # the stride changes the height.  The trailing comments give the height
        # for the (N, 3, 1024, 1) batches built by TextMiningDataModule.
        self.c1 = Conv2d(3, FEATURES_1, (3, 1), 1, (1, 0))  # 1024 -> 1024
        self.c2 = Conv2d(FEATURES_1, FEATURES_1, (3, 1), 1, (1, 0))  # 1024 -> 1024
        self.c3 = Conv2d(FEATURES_1, FEATURES_2, (3, 1), 4, (1, 0))  # 1024 -> 256
        self.c4 = Conv2d(FEATURES_2, FEATURES_3, (3, 1), 2, (1, 0))  # 256 -> 128
        self.c5 = Conv2d(FEATURES_3, FEATURES_4, (3, 1), 2, (1, 0))  # 128 -> 64
        # c6 / bn6 are registered (and counted as parameters) but are not
        # applied in _forward_features.
        self.c6 = Conv2d(FEATURES_4, FEATURES_5, (3, 1), 2, (1, 0))
        # `bn` is shared by the c1 and c2 stages.
        self.bn = BatchNorm2d(FEATURES_1)
        self.bn3 = BatchNorm2d(FEATURES_2)
        self.bn4 = BatchNorm2d(FEATURES_3)
        self.bn5 = BatchNorm2d(FEATURES_4)
        self.bn6 = BatchNorm2d(FEATURES_5)
        # 32 * FEATURES_5 == FEATURES_4 * 64: the flattened c5 output for
        # height-1024, width-1 inputs.
        self.fc1 = Linear(32 * FEATURES_5, 256)
        self.fc2 = Linear(256, 64)
        self.fc3 = Linear(64, 1)

    def _forward_features(self, x):
        """Apply the conv -> ReLU -> batch-norm stages c1 through c5."""
        x = self.bn(F.relu(self.c1(x)))
        x = self.bn(F.relu(self.c2(x)))
        x = self.bn3(F.relu(self.c3(x)))
        x = self.bn4(F.relu(self.c4(x)))
        x = self.bn5(F.relu(self.c5(x)))
        return x

    def forward(self, x):
        """Return one prediction per sample, shape ``(batch, 1)``."""
        x = self._forward_features(x)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # NOTE(review): tanh bounds predictions to (-1, 1) while the targets are
        # z-scores -- confirm this output range is intended.
        x = torch.tanh(self.fc3(x))
        return x

    def _compute_metrics(self, batch):
        """Run the model on ``batch`` and evaluate the loss and all metrics.

        Returns:
            Tuple ``(loss, mse, mae, kendalltau, pearson, spearman)``.
        """
        x, y = batch
        logits = self(x)
        return (
            F.mse_loss(logits, y),
            mean_squared_error(logits, y),
            mean_absolute_error(logits, y),
            t_kendalltau(logits, y),
            t_pearson(logits, y),
            t_spearman(logits, y),
        )

    def training_step(self, batch, batch_idx):
        """Log the training metrics and return the MSE loss."""
        loss, mse, mae, k, p, s = self._compute_metrics(batch)

        self.log("train_loss", loss, on_step=True, on_epoch=True, logger=True)
        self.log("train_mse", mse, on_step=True, on_epoch=True, logger=True)
        self.log("train_mae", mae, on_step=True, on_epoch=True, logger=True)
        self.log("train_kendalltau", k, on_step=False, on_epoch=True, logger=True)
        self.log("train_pearson", p, on_step=False, on_epoch=True, logger=True)
        self.log("train_spearman", s, on_step=False, on_epoch=True, logger=True)

        return loss

    def validation_step(self, batch, batch_idx):
        """Log the validation metrics and return the MSE loss."""
        loss, mse, mae, k, p, s = self._compute_metrics(batch)

        self.log("val_loss", loss, prog_bar=True)
        self.log("val_mse", mse, prog_bar=True)
        self.log("val_mae", mae, prog_bar=True)
        self.log("val_kendalltau", k, on_step=False, prog_bar=True)
        self.log("val_pearson", p, on_step=False, prog_bar=True)
        self.log("val_spearman", s, on_step=False, prog_bar=True)

        return loss

    def test_step(self, batch, batch_idx):
        """Log the test metrics and return the MSE loss."""
        loss, mse, mae, k, p, s = self._compute_metrics(batch)

        self.log("test_loss", loss, on_step=True, prog_bar=True)
        self.log("test_mse", mse, on_step=True, prog_bar=True)
        self.log("test_mae", mae, on_step=True, prog_bar=True)
        self.log("test_kendalltau", k, prog_bar=True)
        self.log("test_pearson", p, prog_bar=True)
        self.log("test_spearman", s, prog_bar=True)

        return loss

    def configure_optimizers(self):
        """Adam over all parameters with the configured learning rate."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer


## FourierTransformerDecoder

In [84]:
def fourier_transform2d(x: torch.Tensor) -> torch.Tensor:
    """Return the real part of the 2-D discrete Fourier transform of ``x``.

    The transform is taken over the last two dimensions of ``x``.

    Args:
        x (torch.Tensor): Input tensor with at least two dimensions.

    Returns:
        torch.Tensor: Real component of ``torch.fft.fft2`` applied to ``x``;
            the imaginary component is discarded.
    """
    spectrum = torch.fft.fft2(x, dim=(-1, -2))
    return torch.real(spectrum)

In [None]:
class DecoderLayer(Module):
    """Layer that mixes its input with a 2-D Fourier transform and a feed-forward net.

    Args:
        d_model: Size of the last dimension of the input (normalized by the
            ``LayerNorm`` modules and consumed by the feed-forward block).
        dim_ff: Hidden width of the feed-forward block.
        dropout: Dropout probability used by every ``Dropout`` in the layer.
    """

    def __init__(self, d_model: int, dim_ff: int = 2048, dropout: float = 0.1):
        super().__init__()
        self.ff = Sequential(
            Linear(d_model, dim_ff),
            GELU(),
            Dropout(dropout),
            Linear(dim_ff, d_model),
            Dropout(dropout),
        )
        self.norm1 = LayerNorm(d_model)
        self.dropout1 = Dropout(dropout)
        self.norm2 = LayerNorm(d_model)
        self.dropout2 = Dropout(dropout)
        self.norm3 = LayerNorm(d_model)

    def forward(self, x, *args, **kwargs):
        """Apply the Fourier, residual and feed-forward sub-blocks to ``x``.

        Extra positional and keyword arguments are accepted and ignored, because
        ``TransformerDecoder.forward`` passes ``memory`` and mask arguments to
        each of its layers.

        Args:
            x: Input tensor whose last dimension is ``d_model``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        # Out-of-place additions throughout: an in-place `x += ...` would mutate
        # the caller's tensor and, since `residual` aliases it, the residual too.
        residual = x
        x = self.norm1(x + self.dropout1(fourier_transform2d(x)))
        x = self.norm2(x + self.dropout2(residual))
        x = self.norm3(x + self.ff(x))
        return x

class FNet(TransformerDecoder):
    """Stack of ``num_layers`` ``DecoderLayer`` clones.

    Args:
        num_layers: Number of stacked layers.
        d_model: Passed to ``DecoderLayer``.
        dim_ff: Passed to ``DecoderLayer``.
        dropout: Passed to ``DecoderLayer``.
    """

    def __init__(self, num_layers, d_model: int, dim_ff: int = 2048, dropout: float = 0.1):
        decoder_layer = DecoderLayer(d_model, dim_ff, dropout)
        super().__init__(decoder_layer, num_layers)
        self.num_layers = num_layers

    def forward(self, tgt, memory=None, *args, **kwargs):
        """Run ``tgt`` through every layer in order.

        The inherited ``TransformerDecoder.forward`` requires a ``memory``
        tensor and hands it, together with mask keyword arguments, to each
        layer; ``DecoderLayer.forward`` takes only the input tensor.  This
        override therefore calls each layer with the running output alone.
        ``memory`` and any further arguments are accepted for call
        compatibility and ignored.

        Args:
            tgt: Input tensor.
            memory: Ignored.

        Returns:
            Output of the last layer, passed through ``self.norm`` when one is set.
        """
        output = tgt
        for layer in self.layers:
            output = layer(output)
        if self.norm is not None:
            output = self.norm(output)
        return output

In [66]:
# Channel widths used by the model's layers; each one doubles the previous.
FEATURES_1 = 64
FEATURES_2 = 2 * FEATURES_1
FEATURES_3 = 2 * FEATURES_2
FEATURES_4 = 2 * FEATURES_3
FEATURES_5 = 2 * FEATURES_4

class Decoder(TransformerDecoder):
    """``TransformerDecoder`` subclass whose ``norm`` argument must be given.

    Args:
        decoder_layer: Layer instance handed to ``TransformerDecoder``.
        num_layers: Number of stacked layers.
        norm: Final normalization module (may be ``None``).
    """

    def __init__(self, decoder_layer, num_layers, norm):
        super().__init__(
            decoder_layer=decoder_layer,
            num_layers=num_layers,
            norm=norm,
        )

class Model(pl.LightningModule):
    """Convolutional regressor mapping a 3-channel embedding stack to one score.

    NOTE: this definition shadows the ``Model`` class defined in the
    "Convolutional" section once this cell is executed.

    Five (3, 1)-kernel convolutions are followed by three fully connected
    layers and a final ``tanh``.

    Args:
        input_shape: Recorded through ``save_hyperparameters``; the layer sizes
            below do not depend on it.
        learning_rate: Learning rate handed to the Adam optimizer.
    """

    def __init__(self, input_shape, learning_rate=0.001):
        super().__init__()

        self.save_hyperparameters()
        self.learning_rate = learning_rate

        # Conv2d positional arguments: in_channels, out_channels, kernel_size,
        # stride, padding.  Every kernel is (3, 1) with (1, 0) padding, so only
        # the stride changes the height.  The trailing comments give the height
        # for the (N, 3, 1024, 1) batches built by TextMiningDataModule.
        self.c1 = Conv2d(3, FEATURES_1, (3, 1), 1, (1, 0))  # 1024 -> 1024
        self.c2 = Conv2d(FEATURES_1, FEATURES_1, (3, 1), 1, (1, 0))  # 1024 -> 1024
        self.c3 = Conv2d(FEATURES_1, FEATURES_2, (3, 1), 4, (1, 0))  # 1024 -> 256
        self.c4 = Conv2d(FEATURES_2, FEATURES_3, (3, 1), 2, (1, 0))  # 256 -> 128
        self.c5 = Conv2d(FEATURES_3, FEATURES_4, (3, 1), 2, (1, 0))  # 128 -> 64
        # c6 / bn6 are registered (and counted as parameters) but are not
        # applied in _forward_features.
        self.c6 = Conv2d(FEATURES_4, FEATURES_5, (3, 1), 2, (1, 0))
        # `bn` is shared by the c1 and c2 stages.
        self.bn = BatchNorm2d(FEATURES_1)
        self.bn3 = BatchNorm2d(FEATURES_2)
        self.bn4 = BatchNorm2d(FEATURES_3)
        self.bn5 = BatchNorm2d(FEATURES_4)
        self.bn6 = BatchNorm2d(FEATURES_5)
        # Head sizes chain as 32768 -> 256 -> 64 -> 1.  forward() uses fc1, fc2
        # and fc3, so all three must exist and their widths must line up.
        # 32 * FEATURES_5 == FEATURES_4 * 64: the flattened c5 output for
        # height-1024, width-1 inputs.
        self.fc1 = Linear(32 * FEATURES_5, 256)
        self.fc2 = Linear(256, 64)
        self.fc3 = Linear(64, 1)

    def _forward_features(self, x):
        """Apply the conv -> ReLU -> batch-norm stages c1 through c5."""
        x = self.bn(F.relu(self.c1(x)))
        x = self.bn(F.relu(self.c2(x)))
        x = self.bn3(F.relu(self.c3(x)))
        x = self.bn4(F.relu(self.c4(x)))
        x = self.bn5(F.relu(self.c5(x)))
        return x

    def forward(self, x):
        """Return one prediction per sample, shape ``(batch, 1)``."""
        x = self._forward_features(x)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # NOTE(review): tanh bounds predictions to (-1, 1) while the targets are
        # z-scores -- confirm this output range is intended.
        x = torch.tanh(self.fc3(x))
        return x

    def _compute_metrics(self, batch):
        """Run the model on ``batch`` and evaluate the loss and all metrics.

        Returns:
            Tuple ``(loss, mse, mae, kendalltau, pearson, spearman)``.
        """
        x, y = batch
        logits = self(x)
        return (
            F.mse_loss(logits, y),
            mean_squared_error(logits, y),
            mean_absolute_error(logits, y),
            t_kendalltau(logits, y),
            t_pearson(logits, y),
            t_spearman(logits, y),
        )

    def training_step(self, batch, batch_idx):
        """Log the training metrics and return the MSE loss."""
        loss, mse, mae, k, p, s = self._compute_metrics(batch)

        self.log("train_loss", loss, on_step=True, on_epoch=True, logger=True)
        self.log("train_mse", mse, on_step=True, on_epoch=True, logger=True)
        self.log("train_mae", mae, on_step=True, on_epoch=True, logger=True)
        self.log("train_kendalltau", k, on_step=False, on_epoch=True, logger=True)
        self.log("train_pearson", p, on_step=False, on_epoch=True, logger=True)
        self.log("train_spearman", s, on_step=False, on_epoch=True, logger=True)

        return loss

    def validation_step(self, batch, batch_idx):
        """Log the validation metrics and return the MSE loss."""
        loss, mse, mae, k, p, s = self._compute_metrics(batch)

        self.log("val_loss", loss, prog_bar=True)
        self.log("val_mse", mse, prog_bar=True)
        self.log("val_mae", mae, prog_bar=True)
        self.log("val_kendalltau", k, on_step=False, prog_bar=True)
        self.log("val_pearson", p, on_step=False, prog_bar=True)
        self.log("val_spearman", s, on_step=False, prog_bar=True)

        return loss

    def test_step(self, batch, batch_idx):
        """Log the test metrics and return the MSE loss."""
        loss, mse, mae, k, p, s = self._compute_metrics(batch)

        self.log("test_loss", loss, on_step=True, prog_bar=True)
        self.log("test_mse", mse, on_step=True, prog_bar=True)
        self.log("test_mae", mae, on_step=True, prog_bar=True)
        self.log("test_kendalltau", k, prog_bar=True)
        self.log("test_pearson", p, prog_bar=True)
        self.log("test_spearman", s, prog_bar=True)

        return loss

    def configure_optimizers(self):
        """Adam over all parameters with the configured learning rate."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer


# Training + Predictions

In [67]:
# `dims` describes one sample: TextMiningDataModule.setup stacks three channels
# of shape (EMBEDDING_SHAPE1, EMBEDDING_SHAPE2), not 32x32 images.
dm = TextMiningDataModule(256, "de-en", (3, EMBEDDING_SHAPE1, EMBEDDING_SHAPE2))
dm.setup()

In [79]:
# val_samples = next(iter(dm.train_dataloader()))
# val_imgs, val_score = val_samples
# val_imgs.shape, val_score.shape

In [76]:
# wandb_logger = WandbLogger(project="wandb-lightning", job_type="train")
model = Model(input_shape=dm.dims, learning_rate=0.0001)

# Single-GPU trainer capped at 30 epochs; the W&B logger and the early-stopping
# callback stay disabled.
trainer = pl.Trainer(
    gpus=1,
    max_epochs=30,
    progress_bar_refresh_rate=1,
    # logger=wandb_logger,
    # callbacks=[early_stop_callback],
)


GPU available: True, used: True
TPU available: False, using: 0 TPU cores


In [77]:
# Train `model` on the batches served by the data module.
trainer.fit(model, datamodule=dm)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name | Type        | Params
--------------------------------------
0  | c1   | Conv2d      | 640   
1  | c2   | Conv2d      | 12.4 K
2  | c3   | Conv2d      | 24.7 K
3  | c4   | Conv2d      | 98.6 K
4  | c5   | Conv2d      | 393 K 
5  | c6   | Conv2d      | 1.6 M 
6  | bn   | BatchNorm2d | 128   
7  | bn3  | BatchNorm2d | 256   
8  | bn4  | BatchNorm2d | 512   
9  | bn5  | BatchNorm2d | 1.0 K 
10 | bn6  | BatchNorm2d | 2.0 K 
11 | fc1  | Linear      | 8.4 M 
12 | fc2  | Linear      | 16.4 K
13 | fc3  | Linear      | 65    
--------------------------------------
10.5 M    Trainable params
0         Non-trainable params
10.5 M    Total params
42.053    Total estimated model params size (MB)


Epoch 29: 100%|██████████| 86/86 [00:07<00:00, 12.15it/s, loss=0.106, v_num=42, val_loss=1.690, val_mse=1.690, val_mae=1.140, val_kendalltau=0.0253, val_pearson=0.0248, val_spearman=0.0368]


# Test set

In [71]:
# Run the validation loop; no model or datamodule is passed explicitly here.
# NOTE(review): presumably reuses the model/datamodule attached by trainer.fit -- confirm.
trainer.validate()

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


--------------------------------------------------------------------------------
DATALOADER:0 VALIDATE RESULTS
{'val_kendalltau': 0.041737161576747894,
 'val_loss': 1.622091293334961,
 'val_mae': 1.1172127723693848,
 'val_mse': 1.622091293334961,
 'val_pearson': 0.09263288229703903,
 'val_spearman': 0.06260263919830322}
--------------------------------------------------------------------------------


[{'val_loss': 1.622091293334961,
  'val_mse': 1.622091293334961,
  'val_mae': 1.1172127723693848,
  'val_kendalltau': 0.041737161576747894,
  'val_pearson': 0.09263288229703903,
  'val_spearman': 0.06260263919830322}]