# Imports and installation


In [1]:
%%capture
!pip install lightning datasets

In [2]:
import pandas as pd
import torch
from datasets import Dataset, DatasetDict
from torch.utils.data import DataLoader
import torch.nn as nn
import lightning as L
import random

SEED = 124
BATCH_SIZE = 1024
HIDDEN_SIZE = 512
NUM_LAYERS= 2
torch.manual_seed(SEED)
L.seed_everything(SEED)

INFO: Seed set to 124


124

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

# Data Preparation

In [4]:
df = pd.read_csv('/kaggle/input/shortnew/shorthex2hex.csv')
df = df[:40960]

In [5]:
df.head()

Unnamed: 0,text,text_hex,deflate_hex
0,One of the other,4f6e65206f6620746865206f74686572,789cf3cf4b55c84f5328c9005240a208002eb405bb
1,A wonderful little production.,4120776f6e64657266756c206c6974746c652070726f64...,789c735428cfcf4b492d4a2bcd51c8c92c29c949552828...
2,I thought this was,492074686f75676874207468697320776173,789cf35428c9c82f4dcf2801d299c50ae589c5003dea06b0
3,Basically there's a family,4261736963616c6c79207468657265277320612066616d...,789c734a2cce4c4eccc9a95428c9482d4a552f56485448...
4,"Petter Mattei's ""Love in",506574746572204d6174746569277320224c6f766520696e,789c0b482d29492d52f04d045299eac50a4a3ef965a90a...


Instead of using the standard \<EOS> and \<SOS> tags we're using the letter S and E since they are not present in the vocabulary

In [6]:
df['text_hex'] = 'S' + df['text_hex'] + 'E'
df['deflate_hex'] = 'S' + df['deflate_hex'] + 'E'

In [7]:
ds = Dataset.from_pandas(df)
ds_train_test = ds.train_test_split(test_size=0.2, seed=SEED)
ds_test_dev = ds_train_test['test'].train_test_split(test_size=0.5, seed=SEED)
ds_splits = DatasetDict({
    'train': ds_train_test['train'],
    'valid': ds_test_dev['train'],
    'test': ds_test_dev['test']
})

ds_splits

DatasetDict({
    train: Dataset({
        features: ['text', 'text_hex', 'deflate_hex'],
        num_rows: 32768
    })
    valid: Dataset({
        features: ['text', 'text_hex', 'deflate_hex'],
        num_rows: 4096
    })
    test: Dataset({
        features: ['text', 'text_hex', 'deflate_hex'],
        num_rows: 4096
    })
})

# Tokenizzare in caratteri singoli o in sequenze di caratteri?

In [8]:
ds_splits['train'][0]

{'text': 'I wonder what audiences',
 'text_hex': 'S4920776f6e64657220776861742061756469656e636573E',
 'deflate_hex': 'S789cf35428cfcf4b492d5228cf482c51482c4dc94ccd4b4e2d060063a8089eE'}

## Data tokenization

In [9]:
token2id = {"0": 0, "1": 1, "2": 2, "3": 3, "4": 4, "5": 5, "6": 6, "7": 7, "8": 8, "9": 9, "a": 10, "b": 11, "c": 12, "d": 13, "e": 14, "f": 15, "P":16, "S": 17, "E":18 }

In [10]:
def create_id2token_vocab(token_to_id):
    id2token = {}
    for token, id in token_to_id.items():
        id2token[id] = token

    return id2token

id2token = create_id2token_vocab(token2id)
id2token

{0: '0',
 1: '1',
 2: '2',
 3: '3',
 4: '4',
 5: '5',
 6: '6',
 7: '7',
 8: '8',
 9: '9',
 10: 'a',
 11: 'b',
 12: 'c',
 13: 'd',
 14: 'e',
 15: 'f',
 16: 'P',
 17: 'S',
 18: 'E'}

In [11]:
# Corrected collate_fn function
def collate_fn(batch):

    texts = [elem['text_hex'] for elem in batch]
    encoded_hex = [[token2id[char] for char in text] for text in texts]

    outputs = [elem['deflate_hex'] for elem in batch]
    encoded_outputs = [[token2id[char] for char in output] for output in outputs]

    #pad sequences
    encoded_hex = nn.utils.rnn.pad_sequence([torch.tensor(seq) for seq in encoded_hex], batch_first=True, padding_value=token2id['P'])
    encoded_outputs = nn.utils.rnn.pad_sequence([torch.tensor(seq) for seq in encoded_outputs], batch_first=True, padding_value=token2id['P'])

    return {
        'inputs': encoded_hex,
        'outputs': encoded_outputs
    }

MAX_SEQ_LEN = 256

In [12]:
train_dataloader = DataLoader(ds_splits['train'], batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn, num_workers = 3)
val_dataloader = DataLoader(ds_splits['valid'], batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn, num_workers = 3)

torch.set_printoptions(profile="full")

for batch in train_dataloader:
    print(batch['inputs'].shape)
    break

torch.Size([1024, 92])


In [18]:
import pytorch_lightning as pl

class myRNN(pl.LightningModule):
    def __init__(self, input_dim, emb_dim, hidden_dim, output_dim, num_layers = NUM_LAYERS, dropout_rate=0.5):
        super().__init__()
        self.emb = nn.Embedding(input_dim, emb_dim).to(device)
        self.rnn = nn.GRU(emb_dim, hidden_dim, batch_first=True, num_layers=num_layers).to(device)
        self.fc = nn.Sequential(nn.Linear(hidden_dim, 256).to(device),
                                nn.ReLU(),
                                nn.Dropout(dropout_rate),  # Adding dropout
                                nn.Linear(256, 64).to(device),
                                nn.ReLU(),
                                nn.Linear(64, output_dim).to(device)
                               )
        self.loss_fn = nn.CrossEntropyLoss(ignore_index=token2id['P']).to(device)
        self.to(device)


    def forward(self, input):
        hidden_state = torch.zeros(NUM_LAYERS, HIDDEN_SIZE, device=device)

        #transpose input
        input = input.transpose(0, 1)

        # Process input sequence
        for char in input:
            embedded_char = self.emb(char) 
            _, hidden_state = self.rnn(embedded_char, hidden_state)

        generated_sequence = []
        input_char = self.emb(torch.tensor([token2id['S']], device=device).long()).repeat(BATCH_SIZE, 1)

        generated_sequence = [MAX_SEQ_LEN, BATCH_SIZE, 19]
        generated_sequence = torch.zeros(generated_sequence, device=device)
        generated_sequence[:, :, -3] = 1  # Set the specific padding element


        seq_len = 0  # To keep track of the actual sequence length
        while seq_len < MAX_SEQ_LEN:
            output, hidden_state = self.rnn(input_char, hidden_state)
            output = self.fc(output)
            
            generated_sequence[seq_len] = output  # Assign output to the tensor directly
            seq_len += 1  # Increment actual sequence length
            
            input_char = self.emb(torch.argmax(output, dim=1))  # Assuming output is a tensor

        return generated_sequence

    def step(self, batch):
        # Assuming 'self' can process batch inputs. 
        # Convert entire batch to device at once (if not already done outside this code snippet).
        inputs, outputs = batch['inputs'].to(device), batch['outputs'].to(device)

        # Predict in batch
        predictions = self(inputs).to(device)  # Assuming self(inputs) returns a batch of sequences
        #transpose predictions
        predictions = predictions.transpose(0, 1)

        padding = [19]
        padding = torch.zeros(padding, device=device)
        padding[-3] = 1  # Set the specific padding element

        #substitute every char in the predictions after the eos token with the padding token
        for i in range(predictions.shape[0]):
            for j in range(predictions.shape[1]):
                if torch.argmax(predictions[i][j]) == token2id['E']:
                    print(predictions[i])
                    predictions[i][j+1:] = padding
                    print(predictions[i])
                    break

        #pad the outputs with padding token to match the length of the predictions
        outputs_padded = [BATCH_SIZE, MAX_SEQ_LEN]
        outputs_padded = torch.zeros(outputs_padded, device=device)
        outputs_padded[:, :] = token2id['P']
        outputs_padded[:, :outputs.shape[1]] = outputs
        
        #Compute the argmax of the predictions
        TO_PRINT_predictions = torch.argmax(predictions, dim=2)

        #print the first element of the predictions and the outputs_padded
        print(f"TO_PRINT_predictions[0] = {TO_PRINT_predictions[0]}")
        print(f"outputs_padded[0] = {outputs_padded[0]}")

        # Reshape predictions and new_output as needed
        predictions = predictions.reshape(-1, 19)
        outputs_padded = outputs_padded.reshape(-1).long()

        return self.loss_fn(predictions, outputs_padded)

    def training_step(self, batch, batch_idx):
        loss = self.step(batch)
        print(f"train_loss = {loss}")
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self.step(batch)
        print(f"val_loss = {loss}")
        self.log("val_loss", loss)
        return loss

    
    def configure_optimizers(self):
            # Configure optimizers with weight decay for L2 regularization
            optimizer = torch.optim.Adam(self.parameters(), lr=1e-3, weight_decay=1e-5)  # L2 regularization

            # Specify gradient clipping directly in optimizer configuration (PyTorch Lightning handles it)
            optimizer_config = {
                'optimizer': optimizer,
                'gradient_clip_val': 1.0,  # Clip gradients with norm above 1.0
                'gradient_clip_algorithm': 'norm'  # 'norm' uses the total norm of all parameters, 'value' clips by value directly
            }
            return optimizer_config

model = myRNN(input_dim=len(token2id), emb_dim=128, hidden_dim=HIDDEN_SIZE, output_dim=len(token2id)).to(device)

trainer = pl.Trainer(max_epochs=10)
trainer.fit(model, train_dataloader, val_dataloader)



Sanity Checking: |          | 0/? [00:00<?, ?it/s]

TO_PRINT_predictions[0] = tensor([9, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], device='cuda:0')
outputs_padded[0] = tensor([17.,  7.,  8.,  9., 12.,  7.,  3.,  5.,  4.,  4.,  8., 12., 11., 12.,
  

Training: |          | 0/? [00:00<?, ?it/s]

TO_PRINT_predictions[0] = tensor([ 9,  1,  1,  1,  9,  1,  7,  1,  1,  1,  7,  1,  7,  1,  1,  1,  7,  7,
         1,  1,  9,  1,  1,  9,  9,  1,  1,  1,  1,  1,  1,  1, 10,  1,  1,  1,
         1,  1,  7,  9, 10,  9,  1,  7,  9,  1,  1,  9,  1,  1,  7,  1,  1,  1,
         9,  1,  7,  1,  1,  1,  1,  1,  1,  7,  1,  1,  9,  7,  1,  9,  1,  1,
         1,  1,  1,  1,  1,  9,  9,  1,  7,  1,  1,  7,  1,  1,  1,  1,  1,  7,
         1,  1,  1,  9,  7,  1,  1,  1,  1,  1,  1,  1, 10,  1,  1, 10,  1,  1,
         1,  9,  1,  1,  1,  1,  1,  1,  1,  1,  1,  9,  1,  7, 10,  1,  7,  9,
         7,  1,  1,  1,  1,  1,  7,  7,  1,  1,  1,  1,  1,  1,  1,  1,  9,  1,
         1,  7,  1,  1,  7,  1,  1,  1,  1,  1,  7,  1,  1,  7,  1,  1,  1,  1,
         1,  1,  1,  1,  9,  1,  1,  1,  1,  1,  7,  1,  1,  1,  1, 10,  1,  1,
         1,  1,  9,  9,  1,  1,  1,  1,  1,  1,  1,  9,  1,  1,  7,  1,  1,  9,
         1,  1,  1,  1,  1,  9, 10,  9, 10,  1,  1,  7,  1,  1,  1,  1,  1,  1,
         1, 10

Validation: |          | 0/? [00:00<?, ?it/s]

TO_PRINT_predictions[0] = tensor([17,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,
         8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,
         8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,
         8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,
         8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,
         8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,
         8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,
         8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,
         8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,
         8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,
         8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,
         8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,
         8,  8

Validation: |          | 0/? [00:00<?, ?it/s]

TO_PRINT_predictions[0] = tensor([17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17,
        17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17,
        17, 17,  7,  9, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
        12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
        12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
        12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
        12, 12,  0, 12,  0, 12,  0, 12,  0, 12,  0, 12,  0, 12,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0

Validation: |          | 0/? [00:00<?, ?it/s]

TO_PRINT_predictions[0] = tensor([17,  7,  8,  9, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
        12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
        12, 12, 12, 12, 12, 12, 12, 12, 12,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0

Validation: |          | 0/? [00:00<?, ?it/s]

TO_PRINT_predictions[0] = tensor([17,  7,  8,  9,  9, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
        12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
        12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0

Validation: |          | 0/? [00:00<?, ?it/s]

TO_PRINT_predictions[0] = tensor([17,  7,  8,  9, 12, 12,  3, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
        12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
        12, 12, 12, 12, 12,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0

Validation: |          | 0/? [00:00<?, ?it/s]

TO_PRINT_predictions[0] = tensor([17, 17,  7,  8,  9, 12,  3,  3, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
        12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
        12, 12, 12, 12, 12, 12,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0

Validation: |          | 0/? [00:00<?, ?it/s]

TO_PRINT_predictions[0] = tensor([17,  7,  8,  9,  9, 12,  0,  0, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
        12, 12, 12, 12, 12, 12, 12, 12,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0

Validation: |          | 0/? [00:00<?, ?it/s]

TO_PRINT_predictions[0] = tensor([17, 17,  7,  8,  9, 12,  0,  0, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
        12,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0

# Model

# Training