# Imports and installation


In [5]:
%%capture
!pip install datasets transformers accelerate evaluate wandb nltk pandas lightning

In [15]:
import pandas as pd
import torch
from datasets import Dataset, DatasetDict
from torch.utils.data import DataLoader
import random
from dataclasses import dataclass
import numpy as np
from transformers import BartTokenizer, BartForConditionalGeneration, T5ForConditionalGeneration, AutoTokenizer
import lightning as L

BATCH_SIZE = 128
SEED = 124

#DATA USED
SHORT = False
RANDOMIZED_SHORT = True
MEDIUM = False

MAX_SEQ_LEN = 512 if MEDIUM else 256

#MODEL USED
FEEDFORWARD = False
FEEDFORWARD_WITH_ATTENTION = False
CONV1D = False
RNN = True
SEQ2SEQ = False

#MODEL CHOICES FOR SEQ2SEQ: bart-base, bart-large, t5-base
MODEL = "bart-base"

#RNN MODELS AND HYPERPARAMETERS
BIDIRECTIONAL = False
RNN_TYPE = 'RNN'  # Options: 'LSTM', 'GRU', 'RNN'

#HYPERPARAMETERS
EMBED_DIM = 128
HIDDEN_DIM = 512
LEARNING_RATE = 5e-4
DROPOUT_RATE = 0.5
NUM_HEADS = 4
NUM_LAYERS = 4
WEIGHT_DECAY = 0.01
MAX_EPOCHS = 5

torch.manual_seed(SEED)
L.seed_everything(SEED)

INFO: Seed set to 124
INFO:lightning.fabric.utilities.seed:Seed set to 124


124

In [16]:
models_values = [FEEDFORWARD, FEEDFORWARD_WITH_ATTENTION, CONV1D, RNN, SEQ2SEQ]
num_true = sum(models_values)

# Check if only one value is True
if num_true == 1:
    print("OK")
else:
    print("ATTENTION! SELECT ONLY ONE MODEL TO RUN")
    ## Using this so that is more than one model is selected the execution does not continue
    print(UNDEFINED_VARIABLE_TO_LET_THE_NOTEBOOK_CRASH)

OK


In [17]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

# Data Preparation

In [18]:
## Mapping from token to id used for encoding hexadecimal strings
token2id = {"0": 0, "1": 1, "2": 2, "3": 3, "4": 4, "5": 5, "6": 6, "7": 7, "8": 8, "9": 9, "a": 10, "b": 11, "c": 12, "d": 13, "e": 14, "f": 15, "P":16, "S": 17, "E":18 }
def create_id2token_vocab(token_to_id):
    id2token = {}
    for token, id in token_to_id.items():
        id2token[id] = token

    return id2token

id2token = create_id2token_vocab(token2id)

## INIZIALIZE OUTPUT DIM NOW THAT I KNOW THE LENGTH OF THE TOKEN2ID DICTIONARY
OUTPUT_DIM = len(token2id)

In [19]:
!wget -O /content/datasets.zip https://github.com/Tommaiberone/Zip-generation/raw/main/Datasets/datasets.zip
!unzip -o /content/datasets.zip

--2024-03-29 17:00:01--  https://github.com/Tommaiberone/Zip-generation/raw/main/Datasets/datasets.zip
Resolving github.com (github.com)... 140.82.112.3
Connecting to github.com (github.com)|140.82.112.3|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/Tommaiberone/Zip-generation/main/Datasets/datasets.zip [following]
--2024-03-29 17:00:01--  https://raw.githubusercontent.com/Tommaiberone/Zip-generation/main/Datasets/datasets.zip
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 8767992 (8.4M) [application/zip]
Saving to: ‘/content/datasets.zip’


2024-03-29 17:00:01 (99.1 MB/s) - ‘/content/datasets.zip’ saved [8767992/8767992]

Archive:  /content/datasets.zip
  inflating: mediumhex2hex.csv       

In [20]:
if SHORT:
  df = pd.read_csv('/content/mediumhex2hex.csv')
elif RANDOMIZED_SHORT:
  df = pd.read_csv('/content/randomized_shorthex2hex.csv')
elif MEDIUM:
  df = pd.read_csv('/content/shorthex2hex.csv')


df = df[:40960]

In [21]:
if RNN:
  df['text_hex'] = 'S' + df['text_hex'] + 'E'
  df['deflate_hex'] = 'S' + df['deflate_hex'] + 'E'

df.head()

Unnamed: 0,text,text_hex,deflate_hex
0,this is not a,S74686973206973206e6f742061E,S789c2bc9c82c5600a2bcfc1285440021fe04a7E
1,"and gives a comforting,",S616e64206769766573206120636f6d666f7274696e672cE,S789c4bcc4b5148cf2c4b2d56485448cecf4dcb2f2ac9c...
2,killer). While some may,S6b696c6c6572292e205768696c6520736f6d65206d6179E,S789ccbceccc9492dd2d45308cfc8cc495528cecf4d55c...
3,in his closet &,S696e2068697320636c6f7365742026E,S789ccbcc53c8c82c5648cec92f4e2d515003002b16052cE
4,film to watch. Mr.,S66696c6d20746f2077617463682e204d722eE,S789c4bcbccc95528c957284f2c49ced053f02dd203003...


Instead of using the standard \<EOS> and \<SOS> tags we're using the letter S and E since they are not present in the vocabulary

In [22]:
if SEQ2SEQ:
    df = df[:15000]
    df[['deflate_hex', 'text_hex', 'text']] += "</s>"

ds = Dataset.from_pandas(df)
ds_train_test = ds.train_test_split(test_size=0.2, seed=SEED)
ds_test_dev = ds_train_test['test'].train_test_split(test_size=0.5, seed=SEED)
ds_splits = DatasetDict({
    'train': ds_train_test['train'],
    'valid': ds_test_dev['train'],
    'test': ds_test_dev['test']
})

## Data tokenization

In [23]:
if SEQ2SEQ:
    if (MODEL == "bart-base"):
        tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
        model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

    elif (MODEL == "bart-large"):
        tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
        model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")

    else:
        tokenizer = AutoTokenizer.from_pretrained("t5-base")
        model = T5ForConditionalGeneration.from_pretrained("t5-base")

In [24]:
def collate_fn(batch):

    if FEEDFORWARD or FEEDFORWARD_WITH_ATTENTION or CONV1D:

        texts = [elem['text_hex'] for elem in batch]
        encoded_hexs = [[token2id[char] for char in text] for text in texts]

        outputs = [elem['deflate_hex'] for elem in batch]
        encoded_outputs = [[token2id[char] for char in output] for output in outputs]

        ## Pad the sequences to MAX_SEQ_LEN chars with the padding token
        padded_hex = [torch.Tensor(encoded_hex + [token2id["P"]] * (MAX_SEQ_LEN - len(encoded_hex))) for encoded_hex in encoded_hexs]
        padded_outputs = [torch.Tensor(encoded_output + [token2id["P"]] * (MAX_SEQ_LEN - len(encoded_output))) for encoded_output in encoded_outputs]

        ## Stack the sequences
        padded_hex = torch.stack(padded_hex).long()
        padded_outputs = torch.stack(padded_outputs).long()


        return {
            'inputs': padded_hex,
            'outputs': padded_outputs
        }

    elif RNN:

        ## Dynamic padding for RNNs
        def pad_sequences(sequences, maxlen, value=token2id['P']):
            padded_sequences = []
            for sequence in sequences:
                padded_sequence = sequence[:maxlen]
                padded_sequence.extend([value] * (maxlen - len(padded_sequence)))

                padded_sequence = sequence +  [value] * (maxlen - len(sequence))
                padded_sequences.append(padded_sequence)

            return padded_sequences


        texts = [elem['text_hex'] for elem in batch]
        encoded_hex = [[token2id[x] for x in hex] for hex in texts]


        outputs = [elem['deflate_hex'] for elem in batch]
        encoded_outputs = [[token2id[x] for x in hex] for hex in outputs]


        maxlen = 0
        for seq in encoded_hex:
            if len(seq) > maxlen:
                maxlen = len(seq)
        for seq in encoded_outputs:
            if len(seq) > maxlen:
                maxlen = len(seq)

        padded_encoded_hex = pad_sequences(encoded_hex, maxlen)
        padded_encoded_outputs = pad_sequences(encoded_outputs, maxlen)


        return {
            'inputs': torch.tensor(padded_encoded_hex),
            "outputs": torch.tensor(padded_encoded_outputs)
        }

    elif SEQ2SEQ:
        inputs = [x["text_hex"] for x in batch]
        outputs = [x["deflate_hex"] for x in batch]
        input_features = tokenizer(inputs, return_tensors="pt", padding=True, truncation=True, max_length=MAX_SEQ_LEN)
        output_features = tokenizer(outputs, return_tensors="pt", padding=True, truncation=True, max_length=MAX_SEQ_LEN)["input_ids"]
        output_features[output_features == tokenizer.pad_token_id] = -100
        return {"input_ids": input_features["input_ids"], "attention_mask": input_features["attention_mask"], "labels": output_features}


## Initializing dataloaders

In [25]:
train_dataloader = DataLoader(ds_splits['train'], batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn, num_workers = 3)
val_dataloader = DataLoader(ds_splits['valid'], batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn, num_workers = 3)
test_dataloader = DataLoader(ds_splits['test'], batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn, num_workers = 3)

torch.set_printoptions(profile="full")



### Evaluation functions

We're using the nltk library to compute the edit distance (Levenshtein distance) between the predicted string and the target/gold string.

In [26]:
import nltk
from nltk.metrics.distance import edit_distance

def decode_output(output):
    return ''.join([id2token[int(id)] for id in output])

def decode_input(input):
    return ''.join([id2token[int(id)] for id in input])

## function used to compute metrics for Seq2Seq models (bart/t5)
def compute_seq2seq_metrics(preds, labels, tokenizer):
    # Ensure labels with -100 are replaced by pad_token_id
    labels = torch.where(labels == -100, tokenizer.pad_token_id, labels)

    # Convert tensors to lists and detach them from cuda
    if torch.is_tensor(preds):
        preds = preds.detach().cpu().tolist()
    if torch.is_tensor(labels):
        labels = labels.detach().cpu().tolist()

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    distances = [edit_distance(p, l) for p, l in zip(decoded_preds, decoded_labels)]
    avg_distance = np.mean(distances)
    count_unzippable = distances.count(0)

    return {"average_edit_distance": avg_distance, "count_unzippable": count_unzippable}

## function used to compute metrics for all the other models
def evaluate(_device, _print, _training):

    model.eval()
    total_distance = 0
    total = 0

    distances_list = []


    for batch in test_dataloader:

        if FEEDFORWARD or FEEDFORWARD_WITH_ATTENTION or CONV1D:
            x = batch["inputs"].to(_device)
            y = batch["outputs"].to(_device)

            y_hat = model(x, y)
            y_hat = torch.argmax(y_hat, dim=-1)

            output = decode_output(y[0])
            output_hat = decode_output(y_hat[0])

            output = [x for x in output if x != "P"]
            output_hat = [x for x in output_hat if x != "P"]

            distance = edit_distance(output, output_hat)

        elif RNN:
            x = batch["inputs"].transpose(0,1).to(_device)
            y = batch["outputs"].transpose(0,1).to(_device)

            y_hat = model(x, y)
            y_hat = torch.argmax(y_hat, dim=-1)

            y = y.transpose(0,1)
            y_hat = y_hat.transpose(0,1)

            assert len(y) == len(y_hat)

            for i in range(len(y)):
                output = decode_output(y[i])
                output_hat = decode_output(y_hat[i])

                ## Remove padding
                output = [x for x in output if x != "P"]
                output_hat = [x for x in output_hat if x != "P"]

                ## Save the index of the first EOS token, if any. Else consider all the string
                first_eos_index = len(output_hat)
                for i in range(len(output_hat)):
                    if output_hat[i] == "E":
                        first_eos_index = i
                        break

                # Remove SOS token
                output = output[1:]
                output_hat = output_hat[1:first_eos_index]

                ## Compute distance
                distance = edit_distance(output, output_hat)
                distances_list.append(distance)

        if _print:
            print(f"output = {output}")
            print(f"output_hat = {output_hat}")

        total_distance += distance
        total += 1

        if distance == 0:
            print(f"DISTANCE = 0!")
            print(f"output = {output}")
            print(f"output_hat = {output_hat}")

        if _training:
            return total_distance/total

    return (total_distance/total, distances_list)

# Models

In [27]:
import torch
import torch.nn as nn
import pytorch_lightning as pl
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR
from pytorch_lightning.callbacks import EarlyStopping
from transformers import get_linear_schedule_with_warmup

## In this class there are the following models:
## 1. Vanilla FeedForward
## 2. FeedForward with Attention
## 3. Conv1D
## 4. Recurrent models
## You can switch between those models using the parameters present in the first cell of this Notebook

class FeedForward(pl.LightningModule):

    def __init__(self, input_dim=MAX_SEQ_LEN, embed_dim = EMBED_DIM, hidden_dim=HIDDEN_DIM, output_dim=OUTPUT_DIM, learning_rate=LEARNING_RATE,
                 dropout_rate=DROPOUT_RATE, bidirectional=BIDIRECTIONAL, num_layers=NUM_LAYERS, optimizer_type=AdamW, scheduler_type=StepLR,
                 scheduler_step_size=5, scheduler_gamma=0.1):
        super().__init__()
        self.save_hyperparameters()

        if FEEDFORWARD or FEEDFORWARD_WITH_ATTENTION:
            self.embed = nn.Embedding(input_dim, embed_dim)
            self.positional_embeddings = nn.Parameter(torch.zeros(BATCH_SIZE, input_dim, embed_dim))
            nn.init.normal_(self.positional_embeddings, mean=0, std=embed_dim ** -0.5)  # Initialize positional embeddings
            self.self_attention = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=NUM_HEADS, dropout=dropout_rate, batch_first = True)
            self.fc1 = nn.Linear(embed_dim, hidden_dim)
            self.norm1 = nn.LayerNorm(hidden_dim)
            self.dropout = nn.Dropout(dropout_rate)
            self.fc2 = nn.Linear(hidden_dim, hidden_dim//2)
            self.norm2 = nn.LayerNorm(hidden_dim//2)
            self.fc3 = nn.Linear(hidden_dim//2, output_dim)

        elif CONV1D:

            # Embedding layer to transform dictionary indices into dense vectors
            self.embedding = nn.Embedding(num_embeddings=input_dim, embedding_dim=embed_dim)

            # Convolutional layers
            self.conv1 = nn.Conv1d(in_channels=embed_dim, out_channels=embed_dim, kernel_size=3, padding=1)

            # Fully connected layers for classification
            self.fc1 = nn.Linear(embed_dim, embed_dim)
            self.fc2 = nn.Linear(embed_dim, output_dim)

            # Hyperparameters
            self.learning_rate = learning_rate

        elif RNN:
            self.rnn_type = RNN_TYPE
            self.embedding = nn.Embedding(output_dim, embed_dim, padding_idx=token2id['P'])

            if self.rnn_type == 'LSTM':
                rnn_cell = nn.LSTM
            elif self.rnn_type == 'GRU':
                rnn_cell = nn.GRU
            else:  # Default to RNN if neither LSTM nor GRU is selected
                rnn_cell = nn.RNN

            self.encoder_rnn = rnn_cell(embed_dim, hidden_dim, num_layers=num_layers, bidirectional=bidirectional, dropout=dropout_rate if num_layers > 1 else 0)
            self.decoder_rnn = rnn_cell(embed_dim, hidden_dim, num_layers=num_layers, bidirectional=bidirectional, dropout=dropout_rate if num_layers > 1 else 0)

            self.dropout = nn.Dropout(dropout_rate)
            self.output_dim = output_dim
            self.linear = nn.Linear(hidden_dim * 2 if bidirectional else hidden_dim, output_dim)
            self.criterion = nn.CrossEntropyLoss()

        self.loss = nn.CrossEntropyLoss()

    def forward(self, x, target, teacher_forcing_ratio=0.5):

        if FEEDFORWARD or FEEDFORWARD_WITH_ATTENTION:

            # Embedding
            x = self.embed(x)  # Shape: [Batch, Seq_len, Embed_dim]

            if (FEEDFORWARD_WITH_ATTENTION):

                # Add positional embeddings
                positions = self.positional_embeddings
                x = x + positions

                # Self-attention
                attn_output, _ = self.self_attention(x, x, x)

                x = torch.relu(self.norm1(self.fc1(attn_output)))

            else:
                x = torch.relu(self.norm1(self.fc1(x)))

            x = self.dropout(x)
            x = torch.relu(self.norm2(self.fc2(x)))
            x = self.fc3(x)

            return x

        elif CONV1D:

            # Embedding layer
            x = self.embedding(x)

            # Transpose from (batch_size, sequence_length, embedding_dim) to (batch_size, embedding_dim, sequence_length)
            x = x.permute(0, 2, 1)

            x = torch.relu(self.conv1(x))

            x = x.permute(0, 2, 1)

            x = torch.relu(self.fc1(x))

            x = self.fc2(x)

            return x

        elif RNN:
            target_len = target.shape[0]
            batch_size = target.shape[1]
            target_vocab_size = self.output_dim

            outputs = torch.zeros(target_len, batch_size, target_vocab_size).to(self.device)

            x = self.dropout(self.embedding(x))
            rnn_output, h = self.encoder_rnn(x)

            x = target[0]
            for t in range(1, target_len):
                x = self.dropout(self.embedding(x.unsqueeze(0)))
                out, h = self.decoder_rnn(x, h if self.rnn_type in ['LSTM', 'GRU'] else None)
                predictions = self.linear(out)
                predictions = predictions.squeeze(0)
                outputs[t] = predictions
                pred = predictions.argmax(1)
                x = target[t] if random.random() < teacher_forcing_ratio else pred

            return outputs

    def configure_optimizers(self):
        optimizer = self.hparams.optimizer_type(self.parameters(), lr=self.hparams.learning_rate)
        scheduler = self.hparams.scheduler_type(optimizer, step_size=self.hparams.scheduler_step_size, gamma=self.hparams.scheduler_gamma)
        return [optimizer], [scheduler]

    def step(self, batch):
        if (FEEDFORWARD or FEEDFORWARD_WITH_ATTENTION or CONV1D):
            x = batch["inputs"]
            y = batch["outputs"]
            y = y.view(y.shape[0] * y.shape[1])
            y_hat = self(x, y)
            y_hat = y_hat.view(y_hat.shape[0] * y_hat.shape[1], y_hat.shape[2])

        elif RNN:
            inputs, targets = batch['inputs'], batch['outputs']
            inputs = inputs.transpose(0, 1)
            targets = targets.transpose(0, 1)

            output = self(inputs, targets)
            output_dim = output.shape[-1]

            y_hat = output.reshape(-1, output_dim)
            y = targets.reshape(-1)

        loss = self.loss(y_hat, y)
        return loss

    def training_step(self, batch, batch_idx):
        loss = self.step(batch)
        self.log('train_loss', loss, prog_bar = True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self.step(batch)
        self.log('val_loss', loss, prog_bar = True)
        return loss



# Recurrent Models

In [30]:
class Recurrent(pl.LightningModule):
      def __init__(self, input_dim=MAX_SEQ_LEN, embed_dim = EMBED_DIM, hidden_dim=HIDDEN_DIM, output_dim=OUTPUT_DIM, learning_rate=LEARNING_RATE,
                 dropout_rate=DROPOUT_RATE, bidirectional=BIDIRECTIONAL, num_layers=NUM_LAYERS, optimizer_type=AdamW, scheduler_type=StepLR,
                 scheduler_step_size=5, scheduler_gamma=0.1):

        super().__init__()
        self.save_hyperparameters()

        self.rnn_type = RNN_TYPE
        self.embedding = nn.Embedding(output_dim, embed_dim, padding_idx=token2id['P'])

        if self.rnn_type == 'LSTM':
          rnn_cell = nn.LSTM
        elif self.rnn_type == 'GRU':
          rnn_cell = nn.GRU
        else:  # Default to RNN if neither LSTM nor GRU is selected
          rnn_cell = nn.RNN

        self.encoder_rnn = rnn_cell(embed_dim, hidden_dim, num_layers=num_layers, bidirectional=bidirectional, dropout=dropout_rate if num_layers > 1 else 0)
        self.decoder_rnn = rnn_cell(embed_dim, hidden_dim, num_layers=num_layers, bidirectional=bidirectional, dropout=dropout_rate if num_layers > 1 else 0)

        self.dropout = nn.Dropout(dropout_rate)
        self.output_dim = output_dim
        self.linear = nn.Linear(hidden_dim * 2 if bidirectional else hidden_dim, output_dim)
        self.criterion = nn.CrossEntropyLoss()

        self.loss = nn.CrossEntropyLoss()



      def forward(self, x, target, teacher_forcing_ratio=0.5):
        target_len = target.shape[0]
        batch_size = target.shape[1]
        target_vocab_size = self.output_dim

        outputs = torch.zeros(target_len, batch_size, target_vocab_size).to(self.device)

        x = self.dropout(self.embedding(x))
        rnn_output, h = self.encoder_rnn(x)

        x = target[0]
        for t in range(1, target_len):
          x = self.dropout(self.embedding(x.unsqueeze(0)))
          out, h = self.decoder_rnn(x, h if self.rnn_type in ['LSTM', 'GRU'] else None)
          predictions = self.linear(out)
          predictions = predictions.squeeze(0)
          outputs[t] = predictions
          pred = predictions.argmax(1)
          x = target[t] if random.random() < teacher_forcing_ratio else pred

        return outputs

      def configure_optimizers(self):
        optimizer = self.hparams.optimizer_type(self.parameters(), lr=self.hparams.learning_rate)
        scheduler = self.hparams.scheduler_type(optimizer, step_size=self.hparams.scheduler_step_size, gamma=self.hparams.scheduler_gamma)
        return [optimizer], [scheduler]

      def step(self, batch):
        inputs, targets = batch['inputs'], batch['outputs']
        inputs = inputs.transpose(0, 1)
        targets = targets.transpose(0, 1)

        output = self(inputs, targets)
        output_dim = output.shape[-1]

        y_hat = output.reshape(-1, output_dim)
        y = targets.reshape(-1)

        loss = self.loss(y_hat, y)
        return loss

      def training_step(self, batch, batch_idx):
          loss = self.step(batch)
          self.log('train_loss', loss, prog_bar = True)
          return loss

      def validation_step(self, batch, batch_idx):
          loss = self.step(batch)
          self.log('val_loss', loss, prog_bar = True)
          return loss


# Seq2Seq Models

In [31]:
class Seq2Seq(pl.LightningModule):
    def __init__(self, tokenizer, model):
        super().__init__()
        self.tokenizer = tokenizer
        self.model = model

    def forward(self, input_ids, attention_mask=None, labels=None):
        return self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)

    def training_step(self, batch, batch_idx):
        outputs = self.forward(**batch)
        self.log('train_loss', outputs.loss, prog_bar=True, logger=True)
        return outputs.loss

    def validation_step(self, batch, batch_idx):
        outputs = self.forward(**batch)
        self.log('val_loss', outputs.loss, prog_bar=True, logger=True)

        preds = torch.argmax(outputs.logits, dim=-1)
        metrics = compute_seq2seq_metrics(preds, batch['labels'], self.tokenizer)
        for key, value in metrics.items():
            self.log(f'{key}', value, prog_bar=True, logger=True)

        return outputs.loss

    def test_step(self, batch, batch_idx):
        outputs = self.forward(**batch)
        self.log('test_loss', outputs.loss, prog_bar=True, logger=True)

        preds = torch.argmax(outputs.logits, dim=-1)
        metrics = compute_seq2seq_metrics(preds, batch['labels'], self.tokenizer)

        for key, value in metrics.items():
            self.log(f'{key}', value, prog_bar=True, logger=True)
        return outputs.loss

    def configure_optimizers(self):
        optimizer = AdamW(self.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
        scheduler = {
            'scheduler': get_linear_schedule_with_warmup(optimizer, num_warmup_steps=500, num_training_steps=10000),
            'name': 'learning_rate',
            'interval': 'step',
            'frequency': 1
        }
        return [optimizer], [scheduler]


# Train!

In [32]:
if SEQ2SEQ:
    model = Seq2Seq(tokenizer, model)
    trainer = pl.Trainer(
        precision='16-mixed',
        max_epochs=MAX_EPOCHS,
        enable_progress_bar=True,
        callbacks=[EarlyStopping(monitor='val_loss', patience=3)]
    )
    trainer.fit(model, train_dataloader, val_dataloader)
    trainer.test(model, test_dataloader)
elif RNN:
    model = Recurrent()
    trainer = pl.Trainer(max_epochs=MAX_EPOCHS)
    trainer.fit(model, train_dataloader, val_dataloader)
    print(evaluate(_device = "cpu", _print = True, _training= False))
else:
  model = FeedForward()
  trainer = pl.Trainer(max_epochs=MAX_EPOCHS)
  trainer.fit(model, train_dataloader, val_dataloader)
  print(evaluate(_device = "cpu", _print = True, _training= False))

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name        | Type             | Params
-------------------------------------------------
0 | embedding   | Embedding        | 2.4 K 
1 | encoder_rnn | RNN              | 1.9 M 
2 | decoder_rnn | RNN              | 1.9 M 
3 | dropout     | Dropout          | 0     
4 | linear      | Linear           | 9.7 K 
5 | criterion   | CrossEntropyLoss | 0     
6 | loss        | CrossEntropyLoss | 0     
-------------------------------------------------
3.8 M     Trainable params
0         Non-trainable params
3.8 M     Total params
15.286    To

Sanity Checking: |          | 0/? [00:00<?, ?it/s]



Training: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...


RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper_CUDA__index_select)