In [2]:
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

# Project-local helpers (the notebook relies on create_batches, LSTM and biLSTM
# coming from these). The star imports keep their original relative order so
# any name shadowing between the two modules is unchanged.
from utils import *
from models import *

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
# --- Test data ---------------------------------------------------------------
# The batches are cached on disk: they are built from the raw sequence file
# only when the cache file is missing, and loaded from the cache otherwise.
# (64 is the second argument of create_batches -- presumably the batch size;
# confirm against utils.)
if not os.path.exists("test_batches.pt"):
    batches = create_batches("test_sequences.txt", 64)
    torch.save(batches, "test_batches.pt")
else:
    batches = torch.load("test_batches.pt")

# --- Model hyperparameters and loss --------------------------------------------
# Shared by both models evaluated in the cells below.
vocab_size, embedding_size, hidden_size = 26, 10, 64
# NLLLoss consumes log-probabilities, so the models are assumed to end in a
# log-softmax -- TODO confirm against the model definitions.
criterion = nn.NLLLoss()

Creazione batch: 100%|██████████| 21875/21875 [05:03<00:00, 72.09 batch/s]


In [4]:
# Evaluate the trained LSTM on the test batches: accumulate the per-batch
# next-character loss and keep the inputs/predictions of the first 10 batches
# for the accuracy report in the following cell.
path = "LSTM_weights.pth"
model = LSTM(vocab_size, embedding_size, hidden_size).to(device)
# map_location remaps the stored tensors onto the active device, so a
# checkpoint saved on a GPU still loads when `device` fell back to the CPU.
model.load_state_dict(torch.load(path, map_location=device))
# Evaluation mode: layers such as dropout or batch-norm (if the model has any)
# must not behave as in training while measuring test loss.
model.eval()

total_loss = 0

input_sequences = []
predictions = []
count = 0

with torch.no_grad(), tqdm(batches) as tqdm_iterator:
    for sequences in tqdm_iterator:
        sequences = sequences.to(device)
        output, _ = model(sequences)
        # NLLLoss wants the class dimension on axis 1.
        # Assumes the model returns (batch, seq, vocab) -- TODO confirm.
        output = output.permute(0, 2, 1)
        # The output at step t is scored against the input character at
        # step t + 1; the last output has no target and is dropped.
        loss = criterion(output[:, :, :-1], sequences[:, 1:])
        if count < 10:
            # No .detach() needed: nothing tracks gradients under no_grad().
            input_sequences.append(sequences.cpu().numpy())
            predictions.append(output.argmax(dim=1).cpu().numpy())

        # Read the scalar once; each .item() call copies from the device.
        batch_loss = loss.item()
        total_loss += batch_loss
        count += 1

        tqdm_iterator.set_postfix({"loss": batch_loss})

# Unweighted mean of the per-batch losses.
avg_loss_Lstm = total_loss/len(batches)
print(f'Average Loss: {avg_loss_Lstm}')

100%|██████████| 21875/21875 [02:22<00:00, 153.40it/s, loss=0.733]

Average Loss: 0.7920521784755162





In [5]:
# Next-character accuracy for each batch captured by the evaluation cell.
# Every "Example" printed here is one stored batch, not a single sequence.
for i, (batch_inputs, batch_predictions) in enumerate(zip(input_sequences, predictions)):
    print(f'Example {i+1}:')

    # The prediction at step t targets the input character at step t + 1, so
    # the first input character and the last prediction are dropped before
    # the element-wise comparison.
    overlap_length = batch_predictions.shape[1]
    input_overlap = batch_inputs[:, 1:overlap_length]
    predictions_overlap = batch_predictions[:, :overlap_length - 1]
    correct_chars = np.sum(input_overlap == predictions_overlap)
    total_chars = input_overlap.size

    # Percentage of correctly predicted characters (the earlier fractional
    # assignment was a dead store and has been removed).
    accuracy = correct_chars / total_chars * 100
    print(f'Accuracy: {accuracy:.2f}% ({correct_chars}/{total_chars} characters)')

Example 1:
Accuracy: 73.47% (93991/127936 characters)
Example 2:
Accuracy: 73.71% (94307/127936 characters)
Example 3:
Accuracy: 73.37% (93873/127936 characters)
Example 4:
Accuracy: 74.00% (94674/127936 characters)
Example 5:
Accuracy: 73.37% (93867/127936 characters)
Example 6:
Accuracy: 73.46% (93976/127936 characters)
Example 7:
Accuracy: 74.12% (94832/127936 characters)
Example 8:
Accuracy: 73.31% (93792/127936 characters)
Example 9:
Accuracy: 73.59% (94146/127936 characters)
Example 10:
Accuracy: 73.59% (94144/127936 characters)


In [6]:
# Evaluate the trained biLSTM on the test batches: accumulate the per-batch
# next-character loss and keep the inputs/predictions of the first 10 batches
# for the accuracy report in the following cell.
#
# NOTE(review): the recorded run of this cell reports a loss of ~1e-8 and the
# next cell reports 100% accuracy. For next-character prediction that is
# suspicious: if biLSTM is bidirectional, its backward direction has already
# read sequences[:, t + 1] when it emits the output for step t, so the target
# leaks into the input. Confirm against the biLSTM definition before comparing
# these numbers with the LSTM's.
path = "biLSTM_weights.pth"
model = biLSTM(vocab_size, embedding_size, hidden_size).to(device)
# map_location remaps the stored tensors onto the active device, so a
# checkpoint saved on a GPU still loads when `device` fell back to the CPU.
model.load_state_dict(torch.load(path, map_location=device))
# Evaluation mode: layers such as dropout or batch-norm (if the model has any)
# must not behave as in training while measuring test loss.
model.eval()

total_loss = 0
input_sequences = []
predictions = []
count = 0

with torch.no_grad(), tqdm(batches) as tqdm_iterator:
    for sequences in tqdm_iterator:
        sequences = sequences.to(device)
        output, _ = model(sequences)
        # NLLLoss wants the class dimension on axis 1.
        # Assumes the model returns (batch, seq, vocab) -- TODO confirm.
        output = output.permute(0, 2, 1)
        # The output at step t is scored against the input character at
        # step t + 1; the last output has no target and is dropped.
        loss = criterion(output[:, :, :-1], sequences[:, 1:])
        if count < 10:
            # No .detach() needed: nothing tracks gradients under no_grad().
            input_sequences.append(sequences.cpu().numpy())
            predictions.append(output.argmax(dim=1).cpu().numpy())

        # Read the scalar once; each .item() call copies from the device.
        batch_loss = loss.item()
        total_loss += batch_loss
        count += 1

        tqdm_iterator.set_postfix({"loss": batch_loss})

# Unweighted mean of the per-batch losses.
avg_loss_biLstm = total_loss/len(batches)
print(f'Average Loss: {avg_loss_biLstm}')

  0%|          | 0/21875 [00:00<?, ?it/s]

100%|██████████| 21875/21875 [04:02<00:00, 90.38it/s, loss=1.52e-8]

Average Loss: 1.7180524523934894e-08





In [7]:
# Next-character accuracy for each batch captured by the evaluation cell.
# Every "Example" printed here is one stored batch, not a single sequence.
for i, (batch_inputs, batch_predictions) in enumerate(zip(input_sequences, predictions)):
    print(f'Example {i+1}:')

    # The prediction at step t targets the input character at step t + 1, so
    # the first input character and the last prediction are dropped before
    # the element-wise comparison.
    overlap_length = batch_predictions.shape[1]
    input_overlap = batch_inputs[:, 1:overlap_length]
    predictions_overlap = batch_predictions[:, :overlap_length - 1]
    correct_chars = np.sum(input_overlap == predictions_overlap)
    total_chars = input_overlap.size

    # Percentage of correctly predicted characters (the earlier fractional
    # assignment was a dead store and has been removed).
    accuracy = correct_chars / total_chars * 100
    print(f'Accuracy: {accuracy:.2f}% ({correct_chars}/{total_chars} characters)')


Example 1:
Accuracy: 100.00% (127936/127936 characters)
Example 2:
Accuracy: 100.00% (127936/127936 characters)
Example 3:
Accuracy: 100.00% (127936/127936 characters)
Example 4:
Accuracy: 100.00% (127936/127936 characters)
Example 5:
Accuracy: 100.00% (127936/127936 characters)
Example 6:
Accuracy: 100.00% (127936/127936 characters)
Example 7:
Accuracy: 100.00% (127936/127936 characters)
Example 8:
Accuracy: 100.00% (127936/127936 characters)
Example 9:
Accuracy: 100.00% (127936/127936 characters)
Example 10:
Accuracy: 100.00% (127936/127936 characters)
