In [1]:
!pip install pytorch-crf



In [2]:
# BiLSTM - CRF Implementation

import torch
import torch.nn as nn
import torch.optim as optim
from torchcrf import CRF
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import DataLoader, TensorDataset
import numpy as np

class BiLSTM_CRF(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, tagset_size, padding_idx):
        super(BiLSTM_CRF, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim // 2, num_layers=1, bidirectional=True, batch_first=True)
        self.hidden2tag = nn.Linear(hidden_dim, tagset_size)
        self.crf = CRF(tagset_size, batch_first=True)

    def forward(self, sentences):
        embeds = self.embedding(sentences)
        lstm_out, _ = self.lstm(embeds)
        lstm_out = self.hidden2tag(lstm_out)
        return lstm_out

    def neg_log_likelihood(self, sentences, tags):
        lstm_out = self.forward(sentences)
        return -self.crf(lstm_out, tags, mask=None, reduction='mean')

    def predict(self, sentences):
        lstm_out = self.forward(sentences)
        return self.crf.decode(lstm_out)

sentences = []
tags = []
with open("/content/train","r") as f:
    sentence,tag1 = [],[]
    for line in f.readlines():
        if(len(line.strip())==0):
            sentences.append(sentence)
            tags.append(tag1)
            tag1 = []
            sentence = []
        else:
            sent,tag = line.strip().split("	")
            sentence.append(sent)
            tag1.append(tag)

max_train = 500
sentences = sentences[:max_train]
tags = tags[:max_train]

word_vocab = set([word for sentence in sentences for word in sentence])
tag_vocab = set([tag for tag_list in tags for tag in tag_list])

word2idx = {word: idx+1 for idx, word in enumerate(word_vocab)}
word2idx['<PAD>'] = 0
tag2idx = {tag: idx for idx, tag in enumerate(tag_vocab)}

X_data = [[word2idx[word] for word in sentence] for sentence in sentences]
y_data = [[tag2idx[tag] for tag in tag_list] for tag_list in tags]

max_len = max(len(sentence) for sentence in X_data)
X_data = [sentence + [0] * (max_len - len(sentence)) for sentence in X_data]
y_data = [tag_list + [tag2idx['O']] * (max_len - len(tag_list)) for tag_list in y_data]

X_tensor = torch.tensor(X_data, dtype=torch.long)
y_tensor = torch.tensor(y_data, dtype=torch.long)

batch_size = 2
train_data = TensorDataset(X_tensor, y_tensor)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

embedding_dim = 100
hidden_dim = 256
tagset_size = len(tag2idx)
padding_idx = word2idx['<PAD>']

model = BiLSTM_CRF(vocab_size=len(word2idx), embedding_dim=embedding_dim, hidden_dim=hidden_dim, tagset_size=tagset_size, padding_idx=padding_idx)
model.to(device)

optimizer = optim.Adam(model.parameters(), lr=0.001)

epochs = 5

from tqdm import tqdm

for epoch in range(epochs):
    model.train()
    total_loss = 0
    with tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}", unit="batch") as batch_bar:
        for batch_sentences, batch_tags in batch_bar:
            batch_sentences = batch_sentences.to(device)
            batch_tags = batch_tags.to(device)

            model.zero_grad()
            loss = model.neg_log_likelihood(batch_sentences, batch_tags)
            total_loss += loss.item()

            loss.backward()
            optimizer.step()
            batch_bar.set_postfix(loss=total_loss / (batch_bar.n + 1))

    print(device)
    print(f"Epoch {epoch+1}/{epochs} - Total Loss: {total_loss:.4f}")

model.eval()

  score = torch.where(mask[i].unsqueeze(1), next_score, score)
Epoch 1/5: 100%|██████████| 250/250 [00:25<00:00,  9.97batch/s, loss=32.1]


cuda
Epoch 1/5 - Total Loss: 8017.7718


Epoch 2/5: 100%|██████████| 250/250 [00:18<00:00, 13.48batch/s, loss=13.5]


cuda
Epoch 2/5 - Total Loss: 3357.2539


Epoch 3/5: 100%|██████████| 250/250 [00:17<00:00, 14.00batch/s, loss=8.44]


cuda
Epoch 3/5 - Total Loss: 2101.1767


Epoch 4/5: 100%|██████████| 250/250 [00:18<00:00, 13.36batch/s, loss=4.97]


cuda
Epoch 4/5 - Total Loss: 1241.8988


Epoch 5/5: 100%|██████████| 250/250 [00:17<00:00, 13.96batch/s, loss=2.94]


cuda
Epoch 5/5 - Total Loss: 731.9922
Predictions for 'John is from New York': ['O', 'O', 'O', 'O', 'O']


In [3]:
import torch

def load_data(file_path):
    sentences, tags = [], []
    with open(file_path, "r") as f:
        sentence, tag1 = [], []
        for line in f.readlines():
            if len(line.strip()) == 0:
                sentences.append(sentence)
                tags.append(tag1)
                sentence, tag1 = [], []
            else:
                try:
                    sent, tag = line.strip().split("\t")
                except ValueError:
                    sent = line.strip()
                    tag = 'O'
                sentence.append(sent)
                tag1.append(tag)
    return sentences, tags

def predict_and_save(model, file_path, output_file, word2idx, tag2idx, max_len_train, device):
    sentences, actual_tags = load_data(file_path)
    X_data = [[word2idx.get(word, word2idx['<PAD>']) for word in sentence] for sentence in sentences]
    max_len_current = max(len(sentence) for sentence in X_data)
    max_len = max(max_len_train, max_len_current)

    X_data = [sentence + [0] * (max_len - len(sentence)) for sentence in X_data]
    X_tensor = torch.tensor(X_data, dtype=torch.long).to(device)

    model.eval()
    predictions = model.predict(X_tensor)
    idx2tag = {idx: tag for tag, idx in tag2idx.items()}
    predicted_tags = [[idx2tag[tag] for tag in pred] for pred in predictions]

    with open(output_file, "w") as f:
        for i in range(len(sentences)):
            pred_tags_i = predicted_tags[i][:len(sentences[i])]
            actual_tags_i = actual_tags[i][:len(sentences[i])]
            for word, pred_tag, actual_tag in zip(sentences[i], pred_tags_i, actual_tags_i):
                f.write(f"{word}\t{pred_tag}\n")
            f.write("\n")

predict_and_save(model, "/content/test", "/content/test_1_output", word2idx, tag2idx, max_len, device)
predict_and_save(model, "/content/dev", "/content/dev_1_output", word2idx, tag2idx, max_len, device)

print("Predictions saved to test.output and dev.output")

Predictions saved to test.output and dev.output


In [4]:
from sklearn.metrics import classification_report

def load_tags(pred_file, actual_file):
    """Loads predicted and actual tags from separate files."""
    predicted, actual = [], []

    with open(pred_file, "r") as f:
        pred = []
        for line in f:
            if line.strip():
                _, pred_tag = line.strip().split("\t")[:2]
                pred.append(pred_tag)
            else:
                predicted.append(pred)
                pred = []

    with open(actual_file, "r") as f:
        act = []
        for line in f:
            if line.strip():
                _, actual_tag = line.strip().split("\t")[:2]
                act.append(actual_tag)
            else:
                actual.append(act)
                act = []

    return actual, predicted

def evaluate(pred_file, actual_file):
    actual, predicted = load_tags(pred_file, actual_file)

    actual_flat = [tag for sentence in actual for tag in sentence]
    predicted_flat = [tag for sentence in predicted for tag in sentence]

    print(f"Evaluation for {pred_file}:")
    print(classification_report(actual_flat, predicted_flat, digits=4))

evaluate("/content/test_1_output", "test.answers")
evaluate("/content/dev_1_output", "dev.answers")

Evaluation for /content/dev_1_output:
              precision    recall  f1-score   support

       B-DNA     0.1538    0.3333    0.2105         6
       B-RNA     0.0000    0.0000    0.0000        15
 B-cell_line     0.1429    0.5000    0.2222         2
 B-cell_type     0.7200    0.3214    0.4444        56
   B-protein     0.4253    0.2984    0.3507       124
       I-DNA     0.3889    0.7778    0.5185         9
       I-RNA     0.5000    0.0769    0.1333        26
 I-cell_line     0.1364    0.5000    0.2143         6
 I-cell_type     0.7955    0.4070    0.5385        86
   I-protein     0.7119    0.2781    0.4000       151
           O     0.8908    0.9677    0.9277      2319

    accuracy                         0.8539      2800
   macro avg     0.4423    0.4055    0.3600      2800
weighted avg     0.8405    0.8539    0.8348      2800



In [1]:
# Character level CNN implementation

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchcrf import CRF
from torch.utils.data import DataLoader, TensorDataset

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class CharCNNEmbedding(nn.Module):
    def __init__(self, char_vocab_size, char_embedding_dim, char_kernel_size, char_filters):
        super(CharCNNEmbedding, self).__init__()
        self.char_embedding = nn.Embedding(char_vocab_size, char_embedding_dim, padding_idx=0)
        self.conv = nn.Conv1d(char_embedding_dim, char_filters, kernel_size=char_kernel_size)
        self.maxpool = nn.AdaptiveMaxPool1d(1)

    def forward(self, x):
        batch_size, seq_len, word_len = x.size()
        x = self.char_embedding(x).view(batch_size * seq_len, word_len, -1).permute(0, 2, 1)
        x = F.relu(self.conv(x))
        x = self.maxpool(x).squeeze(2)
        return x.view(batch_size, seq_len, -1)

class BiLSTM_CRF(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, tagset_size, padding_idx, char_vocab_size, char_embedding_dim, char_hidden_dim):
        super(BiLSTM_CRF, self).__init__()
        self.word_embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)
        self.char_embedding = CharCNNEmbedding(char_vocab_size, char_embedding_dim, 3, char_hidden_dim)
        self.lstm = nn.LSTM(embedding_dim + char_hidden_dim, hidden_dim // 2, bidirectional=True, batch_first=True)
        self.hidden2tag = nn.Linear(hidden_dim, tagset_size)
        self.crf = CRF(tagset_size, batch_first=True)

    def forward(self, sentences, char_inputs):
        word_embeds = self.word_embedding(sentences)
        char_embeds = self.char_embedding(char_inputs)
        combined_embeds = torch.cat([word_embeds, char_embeds], dim=-1)
        lstm_out, _ = self.lstm(combined_embeds)
        tag_scores = self.hidden2tag(lstm_out)
        return tag_scores

    def neg_log_likelihood(self, sentences, tags, char_inputs):
        emissions = self.forward(sentences, char_inputs)
        mask = sentences != 0
        return -self.crf(emissions, tags, mask=mask, reduction='mean')

    def predict(self, sentences, char_inputs):
        emissions = self.forward(sentences, char_inputs)
        return self.crf.decode(emissions)


sentences = []
tags = []
with open("/content/train","r") as f:
    sentence,tag1 = [],[]
    for line in f.readlines():
        if(len(line.strip())==0):
            sentences.append(sentence)
            tags.append(tag1)
            tag1 = []
            sentence = []
        else:
            sent,tag = line.strip().split("	")
            sentence.append(sent)
            tag1.append(tag)

max_train = 500
sentences = sentences[:max_train]
tags = tags[:max_train]

word_vocab = {word: idx + 1 for idx, word in enumerate(set(word for sent in sentences for word in sent))}
word_vocab['<PAD>'] = 0

tag_vocab = {tag: idx for idx, tag in enumerate(set(tag for tag_seq in tags for tag in tag_seq))}

char_vocab = {char: idx + 1 for idx, char in enumerate(set(char for word in word_vocab for char in word))}
char_vocab['<PAD>'] = 0

X_data = [[word_vocab[word] for word in sentence] for sentence in sentences]
y_data = [[tag_vocab[tag] for tag in tag_seq] for tag_seq in tags]

char_data = [[[char_vocab[char] for char in word] for word in sentence] for sentence in sentences]

max_word_len = max(len(word) for sentence in char_data for word in sentence)
max_sentence_len = max(len(sentence) for sentence in X_data)

X_data = [sentence + [0] * (max_sentence_len - len(sentence)) for sentence in X_data]
y_data = [tag_seq + [0] * (max_sentence_len - len(tag_seq)) for tag_seq in y_data]
char_data = [[word + [0] * (max_word_len - len(word)) for word in sentence] for sentence in char_data]
char_data = [sentence + [[0] * max_word_len] * (max_sentence_len - len(sentence)) for sentence in char_data]

X_tensor = torch.tensor(X_data, dtype=torch.long)
y_tensor = torch.tensor(y_data, dtype=torch.long)
char_tensor = torch.tensor(char_data, dtype=torch.long)

X_tensor = X_tensor.to(device)
y_tensor = y_tensor.to(device)
char_tensor = char_tensor.to(device)

vocab_size = len(word_vocab)
char_vocab_size = len(char_vocab)
embedding_dim = 100
char_embedding_dim = 30
char_hidden_dim = 50
hidden_dim = 256
tagset_size = len(tag_vocab)
padding_idx = word_vocab['<PAD>']

model = BiLSTM_CRF(vocab_size, embedding_dim, hidden_dim, tagset_size, padding_idx, char_vocab_size, char_embedding_dim, char_hidden_dim)
model = model.to(device)

optimizer = optim.Adam(model.parameters(), lr=0.001)
model.train()
optimizer.zero_grad()
loss = model.neg_log_likelihood(X_tensor, y_tensor, char_tensor)
loss.backward()
optimizer.step()

model.eval()

predictions = model.predict(test_sentence_tensor, test_char_tensor)
predicted_tags = [list(tag_vocab.keys())[tag] for tag in predictions[0]]
print(predicted_tags)

['B-protein', 'O', 'B-DNA', 'B-cell_type', 'O']


  score = torch.where(mask[i].unsqueeze(1), next_score, score)


In [3]:
import torch

def load_data(file_path):
    sentences, tags = [], []
    with open(file_path, "r") as f:
        sentence, tag1 = [], []
        for line in f.readlines():
            if len(line.strip()) == 0:
                sentences.append(sentence)
                tags.append(tag1)
                sentence, tag1 = [], []
            else:
                try:
                    sent, tag = line.strip().split("\t")
                except ValueError:
                    sent = line.strip()
                    tag = 'O'
                sentence.append(sent)
                tag1.append(tag)
    return sentences, tags

def predict_and_save(model, file_path, output_file, word2idx, tag2idx, char2idx, max_word_len, max_len_train, device, batch_size=8):
    sentences, actual_tags = load_data(file_path)

    X_data = [[word2idx.get(word, 0) for word in sentence] for sentence in sentences]
    max_len_current = max(len(sentence) for sentence in X_data)
    max_len = max(max_len_train, max_len_current)

    char_data = [[[char2idx.get(char, 0) for char in word] for word in sentence] for sentence in sentences]
    char_data = [[word + [0] * (max_word_len - len(word)) for word in sentence] for sentence in char_data]
    char_data = [sentence + [[0] * max_word_len] * (max_len - len(sentence)) for sentence in char_data]

    X_data = [sentence + [0] * (max_len - len(sentence)) for sentence in X_data]

    model.eval()
    all_predictions = []

    with torch.no_grad():
        for i in range(0, len(X_data), batch_size):
            batch_X = X_data[i:min(i+batch_size, len(X_data))]
            batch_char = char_data[i:min(i+batch_size, len(char_data))]

            X_tensor = torch.tensor(batch_X, dtype=torch.long).to(device)
            char_tensor = torch.tensor(batch_char, dtype=torch.long).to(device)

            batch_predictions = model.predict(X_tensor, char_tensor)
            all_predictions.extend(batch_predictions)

            del X_tensor, char_tensor
            torch.cuda.empty_cache()

    idx2tag = {idx: tag for tag, idx in tag2idx.items()}
    predicted_tags = [[idx2tag[tag] for tag in pred] for pred in all_predictions]

    with open(output_file, "w") as f:
        for i in range(len(sentences)):
            pred_tags_i = predicted_tags[i][:len(sentences[i])]
            actual_tags_i = actual_tags[i][:len(sentences[i])]
            for word, pred_tag, actual_tag in zip(sentences[i], pred_tags_i, actual_tags_i):
                f.write(f"{word}\t{pred_tag}\n")
            f.write("\n")

batch_size = 16

predict_and_save(model, "/content/test", "/content/test_2_output", word_vocab, tag_vocab, char_vocab, max_word_len, max_sentence_len, device, batch_size)
predict_and_save(model, "/content/dev", "/content/dev_2_output", word_vocab, tag_vocab, char_vocab, max_word_len, max_sentence_len, device, batch_size)

print("Predictions saved to test_2_output and dev_2_output")

Predictions saved to test_2_output and dev_2_output


In [5]:
from sklearn.metrics import classification_report

def load_tags(pred_file, actual_file):
    predicted, actual = [], []

    with open(pred_file, "r") as f:
        pred = []
        for line in f:
            if line.strip():
                _, pred_tag = line.strip().split("\t")[:2]
                pred.append(pred_tag)
            else:
                predicted.append(pred)
                pred = []

    with open(actual_file, "r") as f:
        act = []
        for line in f:
            if line.strip():
                _, actual_tag = line.strip().split("\t")[:2]
                act.append(actual_tag)
            else:
                actual.append(act)
                act = []

    return actual, predicted

def evaluate(pred_file, actual_file):
    actual, predicted = load_tags(pred_file, actual_file)
    predicted = predicted[:len(actual)]
    print(len(actual))
    print(len(predicted))
    actual_flat = [tag for sentence in actual for tag in sentence]
    predicted_flat = [tag for sentence in predicted for tag in sentence]

    print(f"Evaluation for {pred_file}:")
    print(classification_report(actual_flat, predicted_flat, digits=4))

evaluate("/content/test_2_output", "test.answers")
evaluate("/content/dev_2_output", "dev.answers")

3855
3855
Evaluation for /content/test_2_output:
              precision    recall  f1-score   support

       B-DNA     0.0047    0.0502    0.0085      1056
       B-RNA     0.0000    0.0000    0.0000       118
 B-cell_line     0.0026    0.0020    0.0023       497
 B-cell_type     0.0196    0.1615    0.0350      1919
   B-protein     0.0871    0.2041    0.1220      5067
       I-DNA     0.0146    0.0056    0.0081      1789
       I-RNA     0.0019    0.0053    0.0028       187
 I-cell_line     0.0050    0.0020    0.0029       983
 I-cell_type     0.0097    0.0010    0.0018      2987
   I-protein     0.0575    0.0172    0.0265      4774
           O     0.8047    0.5635    0.6628     81615

    accuracy                         0.4702    100992
   macro avg     0.0916    0.0920    0.0793    100992
weighted avg     0.6584    0.4702    0.5440    100992

105
105
Evaluation for /content/dev_2_output:
              precision    recall  f1-score   support

       B-DNA     0.0000    0.0000    

In [15]:
actual, predicted = load_tags("/content/test_2_output", "test.answers")

In [13]:
# BiLSTM CRF and CNN Implementation

import torch
import torch.nn as nn
import torch.optim as optim
from torchcrf import CRF
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import DataLoader, TensorDataset
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

class BiLSTM_CRF(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, tagset_size, padding_idx):
        super(BiLSTM_CRF, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim // 2, num_layers=1, bidirectional=True, batch_first=True)
        self.hidden2tag = nn.Linear(hidden_dim, tagset_size)
        self.crf = CRF(tagset_size, batch_first=True)

    def forward(self, sentences):
        word_embeds = self.embedding(sentences)
        lstm_out, _ = self.lstm(word_embeds)
        lstm_out = self.hidden2tag(lstm_out)
        return lstm_out

    def hamming_loss(self, y_pred, y_true):
        return (y_pred != y_true).float().mean()

    def neg_log_likelihood(self, sentences, tags, loss_type="softmax_margin"):
        lstm_out = self.forward(sentences)
        crf_loss = -self.crf(lstm_out, tags, mask=None, reduction='mean')
        print(f"CRF Loss: {crf_loss.item()}")
        predicted_tags = self.crf.decode(lstm_out)
        predicted_tags = torch.tensor(predicted_tags, dtype=torch.long).to(tags.device)
        hamming_cost = self.hamming_loss(predicted_tags, tags)

        if loss_type == "max_margin":
            margin_loss = torch.max(crf_loss - hamming_cost, torch.tensor(0.0).to(tags.device))
            return margin_loss.mean()

        elif loss_type == "softmax_margin":
            log_sum_exp = torch.logsumexp(lstm_out, dim=2)
            softmax_margin_loss = torch.mean(log_sum_exp - crf_loss + hamming_cost)
            return softmax_margin_loss

        elif loss_type == "ramp_loss":
            ramp_loss = torch.max(crf_loss, torch.tensor(0.0).to(tags.device)) + hamming_cost
            return ramp_loss.mean()

        elif loss_type == "soft_ramp_loss":
            soft_ramp_loss = torch.mean(torch.logsumexp(lstm_out, dim=2) + hamming_cost)
            return soft_ramp_loss

        else:
            return crf_loss


    def predict(self, sentences):
        lstm_out = self.forward(sentences)
        return self.crf.decode(lstm_out)



sentences = []
tags = []
with open("/content/train","r") as f:
    sentence,tag1 = [],[]
    for line in f.readlines():
        if(len(line.strip())==0):
            sentences.append(sentence)
            tags.append(tag1)
            tag1 = []
            sentence = []
        else:
            sent,tag = line.strip().split("	")
            sentence.append(sent)
            tag1.append(tag)

max_train = 500
sentences = sentences[:max_train]
tags = tags[:max_train]

word_vocab = set([word for sentence in sentences for word in sentence])
tag_vocab = set([tag for tag_list in tags for tag in tag_list])

word2idx = {word: idx+1 for idx, word in enumerate(word_vocab)}
word2idx['<PAD>'] = 0
tag2idx = {tag: idx for idx, tag in enumerate(tag_vocab)}

X_data = [[word2idx[word] for word in sentence] for sentence in sentences]
y_data = [[tag2idx[tag] for tag in tag_list] for tag_list in tags]

max_len = max(len(sentence) for sentence in X_data)
X_data = [sentence + [0] * (max_len - len(sentence)) for sentence in X_data]
y_data = [tag_list + [tag2idx['O']] * (max_len - len(tag_list)) for tag_list in y_data]

X_tensor = torch.tensor(X_data, dtype=torch.long)
y_tensor = torch.tensor(y_data, dtype=torch.long)

batch_size = 16
train_data = TensorDataset(X_tensor, y_tensor)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

embedding_dim = 100
hidden_dim = 256
tagset_size = len(tag2idx)
padding_idx = word2idx['<PAD>']

model = BiLSTM_CRF(vocab_size=len(word2idx), embedding_dim=embedding_dim, hidden_dim=hidden_dim, tagset_size=tagset_size, padding_idx=padding_idx)
model.to(device)

optimizer = optim.Adam(model.parameters(), lr=0.001)

epochs = 20

from tqdm import tqdm

for epoch in range(epochs):
    model.train()
    total_loss = 0

    with tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}", unit="batch") as batch_bar:
        for batch_sentences, batch_tags in batch_bar:
            batch_sentences = batch_sentences.to(device)
            batch_tags = batch_tags.to(device)

            model.zero_grad()

            loss = model.neg_log_likelihood(batch_sentences, batch_tags,loss_type="")
            total_loss += loss.item()

            loss.backward()
            optimizer.step()
            batch_bar.set_postfix(loss=total_loss / (batch_bar.n + 1))

    print(device)
    print(f"Epoch {epoch+1}/{epochs} - Total Loss: {total_loss:.4f}")

model.eval()

print(f"Predictions for '{' '.join(test_sentence)}': {predicted_tags}")

Epoch 1/20:   0%|          | 0/32 [00:00<?, ?batch/s]

CRF Loss: 0.0499725341796875


Epoch 1/20:   6%|▋         | 2/32 [00:00<00:07,  3.84batch/s, loss=0.0702]

CRF Loss: 0.09042739868164062
CRF Loss: 0.06840133666992188


Epoch 1/20:  12%|█▎        | 4/32 [00:00<00:05,  5.52batch/s, loss=0.0675]

CRF Loss: 0.06139373779296875
CRF Loss: 0.208465576171875


Epoch 1/20:  19%|█▉        | 6/32 [00:01<00:03,  6.62batch/s, loss=0.0913]

CRF Loss: 0.06934356689453125
CRF Loss: 0.06287765502929688


Epoch 1/20:  25%|██▌       | 8/32 [00:01<00:03,  6.87batch/s, loss=0.0866]

CRF Loss: 0.08157730102539062
CRF Loss: 0.0948638916015625


Epoch 1/20:  31%|███▏      | 10/32 [00:01<00:02,  7.35batch/s, loss=0.0865]

CRF Loss: 0.0771942138671875
CRF Loss: 0.06711959838867188


Epoch 1/20:  38%|███▊      | 12/32 [00:01<00:02,  7.13batch/s, loss=0.0848]

CRF Loss: 0.0857696533203125
CRF Loss: 0.057895660400390625


Epoch 1/20:  44%|████▍     | 14/32 [00:02<00:02,  6.61batch/s, loss=0.0829]

CRF Loss: 0.08462142944335938
CRF Loss: 0.050212860107421875


Epoch 1/20:  50%|█████     | 16/32 [00:02<00:02,  6.89batch/s, loss=0.0776]

CRF Loss: 0.053371429443359375
CRF Loss: 0.05614471435546875
CRF Loss: 0.058971405029296875


Epoch 1/20:  59%|█████▉    | 19/32 [00:02<00:01,  8.52batch/s, loss=0.0763]

CRF Loss: 0.07239532470703125
CRF Loss: 0.07413482666015625
CRF Loss: 0.06475830078125


Epoch 1/20:  72%|███████▏  | 23/32 [00:03<00:00,  9.56batch/s, loss=0.0788]

CRF Loss: 0.08334732055664062
CRF Loss: 0.059963226318359375
CRF Loss: 0.06428909301757812


Epoch 1/20:  78%|███████▊  | 25/32 [00:03<00:00,  9.41batch/s, loss=0.0737]

CRF Loss: 0.051334381103515625
CRF Loss: 0.06818389892578125
CRF Loss: 0.06565475463867188


Epoch 1/20:  91%|█████████ | 29/32 [00:03<00:00,  9.86batch/s, loss=0.0767]

CRF Loss: 0.05376434326171875
CRF Loss: 0.112030029296875
CRF Loss: 0.051845550537109375


Epoch 1/20: 100%|██████████| 32/32 [00:04<00:00,  7.77batch/s, loss=0.072]


CRF Loss: 0.06633377075195312
CRF Loss: 0.0378875732421875
cuda
Epoch 1/20 - Total Loss: 2.3045


Epoch 2/20:   0%|          | 0/32 [00:00<?, ?batch/s]

CRF Loss: 0.045017242431640625


Epoch 2/20:   6%|▋         | 2/32 [00:00<00:02, 10.31batch/s, loss=0.0808]

CRF Loss: 0.03577423095703125


Epoch 2/20:   6%|▋         | 2/32 [00:00<00:02, 10.31batch/s, loss=0.0431]

CRF Loss: 0.048419952392578125


Epoch 2/20:  12%|█▎        | 4/32 [00:00<00:02,  9.74batch/s, loss=0.0555]

CRF Loss: 0.0372772216796875
CRF Loss: 0.02689361572265625


Epoch 2/20:  12%|█▎        | 4/32 [00:00<00:02,  9.74batch/s, loss=0.0387]

CRF Loss: 0.04083251953125


Epoch 2/20:  19%|█▉        | 6/32 [00:00<00:02,  9.95batch/s, loss=0.0406]

CRF Loss: 0.049922943115234375
CRF Loss: 0.032985687255859375


Epoch 2/20:  25%|██▌       | 8/32 [00:00<00:02, 10.06batch/s, loss=0.0453]

CRF Loss: 0.02446746826171875


Epoch 2/20:  31%|███▏      | 10/32 [00:00<00:02, 10.20batch/s, loss=0.0413]

CRF Loss: 0.0300445556640625
CRF Loss: 0.028263092041015625


Epoch 2/20:  31%|███▏      | 10/32 [00:01<00:02, 10.20batch/s, loss=0.0364]

CRF Loss: 0.028224945068359375


Epoch 2/20:  38%|███▊      | 12/32 [00:01<00:01, 10.22batch/s, loss=0.0354]

CRF Loss: 0.032070159912109375


Epoch 2/20:  44%|████▍     | 14/32 [00:01<00:01,  9.92batch/s, loss=0.0379]

CRF Loss: 0.031864166259765625


Epoch 2/20:  47%|████▋     | 15/32 [00:01<00:01,  9.91batch/s, loss=0.0341]

CRF Loss: 0.019256591796875
CRF Loss: 0.03340911865234375


Epoch 2/20:  47%|████▋     | 15/32 [00:01<00:01,  9.91batch/s, loss=0.034] 

CRF Loss: 0.022846221923828125


Epoch 2/20:  53%|█████▎    | 17/32 [00:01<00:01, 10.03batch/s, loss=0.0332]

CRF Loss: 0.030178070068359375
CRF Loss: 0.031951904296875


Epoch 2/20:  59%|█████▉    | 19/32 [00:01<00:01, 10.09batch/s, loss=0.035]

CRF Loss: 0.06537246704101562


Epoch 2/20:  66%|██████▌   | 21/32 [00:02<00:01, 10.18batch/s, loss=0.0363]

CRF Loss: 0.031566619873046875
CRF Loss: 0.049350738525390625


Epoch 2/20:  72%|███████▏  | 23/32 [00:02<00:00,  9.84batch/s, loss=0.0342]

CRF Loss: 0.02104949951171875
CRF Loss: 0.023193359375


Epoch 2/20:  78%|███████▊  | 25/32 [00:02<00:00, 10.00batch/s, loss=0.034] 

CRF Loss: 0.02370452880859375
CRF Loss: 0.03878021240234375
CRF Loss: 0.032886505126953125


Epoch 2/20:  91%|█████████ | 29/32 [00:02<00:00, 10.22batch/s, loss=0.0351]

CRF Loss: 0.0442047119140625
CRF Loss: 0.024089813232421875
CRF Loss: 0.03572845458984375


Epoch 2/20: 100%|██████████| 32/32 [00:03<00:00, 10.12batch/s, loss=0.0337]


CRF Loss: 0.0292205810546875
CRF Loss: 0.02825927734375
cuda
Epoch 2/20 - Total Loss: 1.0771


Epoch 3/20:   0%|          | 0/32 [00:00<?, ?batch/s]

CRF Loss: 0.0384979248046875


Epoch 3/20:   6%|▋         | 2/32 [00:00<00:03,  9.27batch/s, loss=0.0358]

CRF Loss: 0.03301239013671875


Epoch 3/20:   6%|▋         | 2/32 [00:00<00:03,  9.27batch/s, loss=0.03]  

CRF Loss: 0.018520355224609375
CRF Loss: 0.0246124267578125


Epoch 3/20:  12%|█▎        | 4/32 [00:00<00:02,  9.81batch/s, loss=0.0382]

CRF Loss: 0.017238616943359375


Epoch 3/20:  19%|█▉        | 6/32 [00:00<00:02, 10.02batch/s, loss=0.0312]

CRF Loss: 0.024242401123046875
CRF Loss: 0.01403045654296875


Epoch 3/20:  19%|█▉        | 6/32 [00:00<00:02, 10.02batch/s, loss=0.0243]

CRF Loss: 0.021121978759765625


Epoch 3/20:  25%|██▌       | 8/32 [00:00<00:02, 10.16batch/s, loss=0.024] 

CRF Loss: 0.025089263916015625
CRF Loss: 0.019756317138671875


Epoch 3/20:  31%|███▏      | 10/32 [00:00<00:02, 10.21batch/s, loss=0.0262]

CRF Loss: 0.0256195068359375


Epoch 3/20:  38%|███▊      | 12/32 [00:01<00:02,  9.91batch/s, loss=0.0262]

CRF Loss: 0.025943756103515625


Epoch 3/20:  38%|███▊      | 12/32 [00:01<00:02,  9.91batch/s, loss=0.0233]

CRF Loss: 0.015285491943359375
CRF Loss: 0.027667999267578125


Epoch 3/20:  44%|████▍     | 14/32 [00:01<00:01, 10.05batch/s, loss=0.0254]

CRF Loss: 0.026386260986328125


Epoch 3/20:  50%|█████     | 16/32 [00:01<00:01, 10.09batch/s, loss=0.0261]

CRF Loss: 0.03459930419921875
CRF Loss: 0.0272064208984375


Epoch 3/20:  50%|█████     | 16/32 [00:01<00:01, 10.09batch/s, loss=0.0246]

CRF Loss: 0.017185211181640625


Epoch 3/20:  56%|█████▋    | 18/32 [00:01<00:01, 10.19batch/s, loss=0.0239]

CRF Loss: 0.01720428466796875
CRF Loss: 0.013607025146484375


Epoch 3/20:  62%|██████▎   | 20/32 [00:01<00:01, 10.25batch/s, loss=0.0246]

CRF Loss: 0.014339447021484375


Epoch 3/20:  69%|██████▉   | 22/32 [00:02<00:01,  9.95batch/s, loss=0.024]

CRF Loss: 0.02306365966796875


Epoch 3/20:  69%|██████▉   | 22/32 [00:02<00:01,  9.95batch/s, loss=0.0229]

CRF Loss: 0.023193359375


Epoch 3/20:  75%|███████▌  | 24/32 [00:02<00:00, 10.02batch/s, loss=0.024]

CRF Loss: 0.02359771728515625
CRF Loss: 0.0228424072265625


Epoch 3/20:  75%|███████▌  | 24/32 [00:02<00:00, 10.02batch/s, loss=0.023]

CRF Loss: 0.017917633056640625


Epoch 3/20:  81%|████████▏ | 26/32 [00:02<00:00, 10.02batch/s, loss=0.0228]

CRF Loss: 0.02447509765625
CRF Loss: 0.014888763427734375


Epoch 3/20:  91%|█████████ | 29/32 [00:02<00:00,  9.96batch/s, loss=0.0223]

CRF Loss: 0.01798248291015625
CRF Loss: 0.018871307373046875
CRF Loss: 0.0194549560546875


Epoch 3/20: 100%|██████████| 32/32 [00:03<00:00, 10.00batch/s, loss=0.0218]


CRF Loss: 0.0093231201171875
cuda
Epoch 3/20 - Total Loss: 0.6968


Epoch 4/20:   0%|          | 0/32 [00:00<?, ?batch/s, loss=0.0207]

CRF Loss: 0.02074432373046875
CRF Loss: 0.026950836181640625


Epoch 4/20:   6%|▋         | 2/32 [00:00<00:02, 10.20batch/s, loss=0.0477]

CRF Loss: 0.01348114013671875


Epoch 4/20:  12%|█▎        | 4/32 [00:00<00:02, 10.16batch/s, loss=0.0282]

CRF Loss: 0.023326873779296875
CRF Loss: 0.015077590942382812


Epoch 4/20:  12%|█▎        | 4/32 [00:00<00:02, 10.16batch/s, loss=0.0199]

CRF Loss: 0.01819610595703125


Epoch 4/20:  19%|█▉        | 6/32 [00:00<00:02, 10.24batch/s, loss=0.0185]

CRF Loss: 0.011810302734375
CRF Loss: 0.014171600341796875


Epoch 4/20:  25%|██▌       | 8/32 [00:00<00:02, 10.25batch/s, loss=0.0205]

CRF Loss: 0.015689849853515625


Epoch 4/20:  31%|███▏      | 10/32 [00:00<00:02, 10.10batch/s, loss=0.0195]

CRF Loss: 0.016063690185546875


Epoch 4/20:  31%|███▏      | 10/32 [00:01<00:02, 10.10batch/s, loss=0.0179]

CRF Loss: 0.0213775634765625
CRF Loss: 0.0142822265625


Epoch 4/20:  38%|███▊      | 12/32 [00:01<00:01, 10.01batch/s, loss=0.0192]

CRF Loss: 0.015392303466796875


Epoch 4/20:  44%|████▍     | 14/32 [00:01<00:01, 10.08batch/s, loss=0.0186]

CRF Loss: 0.01512908935546875
CRF Loss: 0.0162506103515625


Epoch 4/20:  44%|████▍     | 14/32 [00:01<00:01, 10.08batch/s, loss=0.0172]

CRF Loss: 0.011188507080078125


Epoch 4/20:  50%|█████     | 16/32 [00:01<00:01, 10.11batch/s, loss=0.0166]

CRF Loss: 0.013660430908203125
CRF Loss: 0.012445449829101562


Epoch 4/20:  56%|█████▋    | 18/32 [00:01<00:01, 10.24batch/s, loss=0.0174]

CRF Loss: 0.015918731689453125


Epoch 4/20:  62%|██████▎   | 20/32 [00:01<00:01, 10.11batch/s, loss=0.0171]

CRF Loss: 0.013957977294921875
CRF Loss: 0.014003753662109375


Epoch 4/20:  62%|██████▎   | 20/32 [00:02<00:01, 10.11batch/s, loss=0.0161]

CRF Loss: 0.013843536376953125


Epoch 4/20:  69%|██████▉   | 22/32 [00:02<00:01,  8.59batch/s, loss=0.0168]

CRF Loss: 0.01848602294921875


Epoch 4/20:  72%|███████▏  | 23/32 [00:02<00:01,  8.33batch/s, loss=0.0161]

CRF Loss: 0.018383026123046875


Epoch 4/20:  75%|███████▌  | 24/32 [00:02<00:00,  8.20batch/s, loss=0.0162]

CRF Loss: 0.01569366455078125


Epoch 4/20:  78%|███████▊  | 25/32 [00:02<00:00,  8.06batch/s, loss=0.0162]

CRF Loss: 0.0159759521484375


Epoch 4/20:  81%|████████▏ | 26/32 [00:02<00:00,  7.90batch/s, loss=0.0162]

CRF Loss: 0.014583587646484375


Epoch 4/20:  84%|████████▍ | 27/32 [00:02<00:00,  7.70batch/s, loss=0.0162]

CRF Loss: 0.01256561279296875


Epoch 4/20:  88%|████████▊ | 28/32 [00:03<00:00,  7.74batch/s, loss=0.016]

CRF Loss: 0.012546539306640625


Epoch 4/20:  91%|█████████ | 29/32 [00:03<00:00,  7.71batch/s, loss=0.0159]

CRF Loss: 0.016099929809570312


Epoch 4/20:  94%|█████████▍| 30/32 [00:03<00:00,  7.65batch/s, loss=0.0159]

CRF Loss: 0.013147354125976562


Epoch 4/20:  97%|█████████▋| 31/32 [00:03<00:00,  7.18batch/s, loss=0.0158]

CRF Loss: 0.0226287841796875


Epoch 4/20: 100%|██████████| 32/32 [00:03<00:00,  8.80batch/s, loss=0.016]


cuda
Epoch 4/20 - Total Loss: 0.5131


Epoch 5/20:   0%|          | 0/32 [00:00<?, ?batch/s]

CRF Loss: 0.013467788696289062


Epoch 5/20:   3%|▎         | 1/32 [00:00<00:05,  6.12batch/s, loss=0.0135]

CRF Loss: 0.01308441162109375


Epoch 5/20:   6%|▋         | 2/32 [00:00<00:05,  5.84batch/s, loss=0.0133]

CRF Loss: 0.010370254516601562


Epoch 5/20:   9%|▉         | 3/32 [00:00<00:04,  6.06batch/s, loss=0.0123]

CRF Loss: 0.007970809936523438


Epoch 5/20:  12%|█▎        | 4/32 [00:00<00:04,  6.85batch/s, loss=0.0123]

CRF Loss: 0.016414642333984375


Epoch 5/20:  19%|█▉        | 6/32 [00:00<00:03,  8.23batch/s, loss=0.0151]

CRF Loss: 0.01398468017578125
CRF Loss: 0.01284027099609375


Epoch 5/20:  25%|██▌       | 8/32 [00:01<00:02,  8.41batch/s, loss=0.0125]

CRF Loss: 0.01088714599609375
CRF Loss: 0.013525009155273438
CRF Loss: 0.014020919799804688


Epoch 5/20:  38%|███▊      | 12/32 [00:01<00:02,  9.43batch/s, loss=0.0138]

CRF Loss: 0.0117645263671875
CRF Loss: 0.01317596435546875
CRF Loss: 0.010059356689453125


Epoch 5/20:  44%|████▍     | 14/32 [00:01<00:01,  9.47batch/s, loss=0.0121]

CRF Loss: 0.012248992919921875
CRF Loss: 0.00782012939453125
CRF Loss: 0.013355255126953125


Epoch 5/20:  56%|█████▋    | 18/32 [00:02<00:01,  9.95batch/s, loss=0.0134]

CRF Loss: 0.01505279541015625
CRF Loss: 0.017812728881835938
CRF Loss: 0.012845993041992188


Epoch 5/20:  62%|██████▎   | 20/32 [00:02<00:01,  9.91batch/s, loss=0.0129]

CRF Loss: 0.016786575317382812
CRF Loss: 0.012889862060546875
CRF Loss: 0.011915206909179688


Epoch 5/20:  75%|███████▌  | 24/32 [00:02<00:00,  9.94batch/s, loss=0.0133]

CRF Loss: 0.011354446411132812
CRF Loss: 0.012790679931640625


Epoch 5/20:  78%|███████▊  | 25/32 [00:02<00:00,  9.93batch/s, loss=0.0126]

CRF Loss: 0.01139068603515625
CRF Loss: 0.010015487670898438
CRF Loss: 0.013416290283203125


Epoch 5/20:  91%|█████████ | 29/32 [00:03<00:00, 10.23batch/s, loss=0.0129]

CRF Loss: 0.00711822509765625
CRF Loss: 0.012083053588867188
CRF Loss: 0.010068893432617188


Epoch 5/20: 100%|██████████| 32/32 [00:03<00:00,  9.32batch/s, loss=0.0121]


CRF Loss: 0.011766433715820312
CRF Loss: 0.00439453125
cuda
Epoch 5/20 - Total Loss: 0.3867


Epoch 6/20:   0%|          | 0/32 [00:00<?, ?batch/s]

CRF Loss: 0.010829925537109375


Epoch 6/20:   6%|▋         | 2/32 [00:00<00:03,  9.77batch/s, loss=0.0214]

CRF Loss: 0.010591506958007812


Epoch 6/20:   9%|▉         | 3/32 [00:00<00:02,  9.74batch/s, loss=0.0104]

CRF Loss: 0.009695053100585938


Epoch 6/20:  12%|█▎        | 4/32 [00:00<00:02,  9.80batch/s, loss=0.0116]

CRF Loss: 0.01512908935546875


Epoch 6/20:  12%|█▎        | 4/32 [00:00<00:02,  9.80batch/s, loss=0.0118]

CRF Loss: 0.01297760009765625
CRF Loss: 0.012025833129882812


Epoch 6/20:  19%|█▉        | 6/32 [00:00<00:02, 10.04batch/s, loss=0.0142]

CRF Loss: 0.00746917724609375


Epoch 6/20:  25%|██▌       | 8/32 [00:00<00:02,  9.81batch/s, loss=0.0125]

CRF Loss: 0.008863449096679688


Epoch 6/20:  25%|██▌       | 8/32 [00:00<00:02,  9.81batch/s, loss=0.0103]

CRF Loss: 0.0052890777587890625


Epoch 6/20:  31%|███▏      | 10/32 [00:01<00:02,  9.95batch/s, loss=0.0117]

CRF Loss: 0.01271820068359375
CRF Loss: 0.008687973022460938


Epoch 6/20:  31%|███▏      | 10/32 [00:01<00:02,  9.95batch/s, loss=0.0104]

CRF Loss: 0.00971221923828125


Epoch 6/20:  38%|███▊      | 12/32 [00:01<00:02,  9.85batch/s, loss=0.0103]

CRF Loss: 0.01023101806640625


Epoch 6/20:  44%|████▍     | 14/32 [00:01<00:01,  9.97batch/s, loss=0.011]

CRF Loss: 0.009267807006835938
CRF Loss: 0.008737564086914062


Epoch 6/20:  44%|████▍     | 14/32 [00:01<00:01,  9.97batch/s, loss=0.0101]

CRF Loss: 0.009494781494140625


Epoch 6/20:  50%|█████     | 16/32 [00:01<00:01, 10.07batch/s, loss=0.0103]

CRF Loss: 0.012554168701171875
CRF Loss: 0.008544921875


Epoch 6/20:  62%|██████▎   | 20/32 [00:02<00:01, 10.10batch/s, loss=0.0108]

CRF Loss: 0.010532379150390625
CRF Loss: 0.011758804321289062
CRF Loss: 0.012422561645507812


Epoch 6/20:  69%|██████▉   | 22/32 [00:02<00:00, 10.07batch/s, loss=0.0102]

CRF Loss: 0.007925033569335938
CRF Loss: 0.009210586547851562


Epoch 6/20:  75%|███████▌  | 24/32 [00:02<00:00,  9.96batch/s, loss=0.0101]

CRF Loss: 0.0113372802734375
CRF Loss: 0.0074787139892578125
CRF Loss: 0.009937286376953125


Epoch 6/20:  88%|████████▊ | 28/32 [00:02<00:00, 10.07batch/s, loss=0.0104]

CRF Loss: 0.007541656494140625
CRF Loss: 0.011083602905273438
CRF Loss: 0.013322830200195312


Epoch 6/20:  94%|█████████▍| 30/32 [00:03<00:00, 10.11batch/s, loss=0.0102]

CRF Loss: 0.0095062255859375
CRF Loss: 0.01161956787109375
CRF Loss: 0.0029449462890625


Epoch 6/20: 100%|██████████| 32/32 [00:03<00:00, 10.06batch/s, loss=0.0103]


cuda
Epoch 6/20 - Total Loss: 0.3194


Epoch 7/20:   3%|▎         | 1/32 [00:00<00:03,  9.26batch/s, loss=0.00808]

CRF Loss: 0.008077621459960938
CRF Loss: 0.008737564086914062


Epoch 7/20:  12%|█▎        | 4/32 [00:00<00:02, 10.12batch/s, loss=0.012]

CRF Loss: 0.011793136596679688
CRF Loss: 0.007427215576171875
CRF Loss: 0.00829315185546875


Epoch 7/20:  19%|█▉        | 6/32 [00:00<00:02,  9.90batch/s, loss=0.00859]

CRF Loss: 0.008478164672851562
CRF Loss: 0.007312774658203125


Epoch 7/20:  25%|██▌       | 8/32 [00:00<00:02, 10.02batch/s, loss=0.00871]

CRF Loss: 0.007823944091796875
CRF Loss: 0.010408401489257812
CRF Loss: 0.00606536865234375


Epoch 7/20:  38%|███▊      | 12/32 [00:01<00:02,  9.93batch/s, loss=0.0091]

CRF Loss: 0.009004592895507812
CRF Loss: 0.0067043304443359375


Epoch 7/20:  44%|████▍     | 14/32 [00:01<00:01, 10.03batch/s, loss=0.00899]

CRF Loss: 0.0075283050537109375
CRF Loss: 0.009246826171875
CRF Loss: 0.010206222534179688


Epoch 7/20:  50%|█████     | 16/32 [00:01<00:01,  9.94batch/s, loss=0.00827]

CRF Loss: 0.0062923431396484375
CRF Loss: 0.0071849822998046875


Epoch 7/20:  56%|█████▋    | 18/32 [00:01<00:01, 10.06batch/s, loss=0.00824]

CRF Loss: 0.007076263427734375
CRF Loss: 0.008855819702148438
CRF Loss: 0.010580062866210938


Epoch 7/20:  69%|██████▉   | 22/32 [00:02<00:01,  9.98batch/s, loss=0.00879]

CRF Loss: 0.007785797119140625
CRF Loss: 0.009611129760742188


Epoch 7/20:  75%|███████▌  | 24/32 [00:02<00:00, 10.08batch/s, loss=0.00884]

CRF Loss: 0.010896682739257812
CRF Loss: 0.00797271728515625
CRF Loss: 0.0063228607177734375


Epoch 7/20:  81%|████████▏ | 26/32 [00:02<00:00,  9.92batch/s, loss=0.00834]

CRF Loss: 0.008378982543945312
CRF Loss: 0.00710296630859375


Epoch 7/20:  88%|████████▊ | 28/32 [00:02<00:00, 10.06batch/s, loss=0.00839]

CRF Loss: 0.008028030395507812
CRF Loss: 0.009984970092773438
CRF Loss: 0.00853729248046875


Epoch 7/20: 100%|██████████| 32/32 [00:03<00:00, 10.04batch/s, loss=0.0086]


CRF Loss: 0.008783340454101562
CRF Loss: 0.00608062744140625
cuda
Epoch 7/20 - Total Loss: 0.2666


Epoch 8/20:   0%|          | 0/32 [00:00<?, ?batch/s]

CRF Loss: 0.0071430206298828125


Epoch 8/20:   3%|▎         | 1/32 [00:00<00:03,  9.68batch/s, loss=0.00774]

CRF Loss: 0.008331298828125


Epoch 8/20:   9%|▉         | 3/32 [00:00<00:02,  9.94batch/s, loss=0.012]

CRF Loss: 0.008592605590820312
CRF Loss: 0.0070629119873046875


Epoch 8/20:  16%|█▌        | 5/32 [00:00<00:02,  9.63batch/s, loss=0.00794]

CRF Loss: 0.008600234985351562
CRF Loss: 0.007900238037109375
CRF Loss: 0.00838470458984375


Epoch 8/20:  25%|██▌       | 8/32 [00:00<00:02,  8.62batch/s, loss=0.00794]

CRF Loss: 0.007503509521484375
CRF Loss: 0.00891876220703125


Epoch 8/20:  31%|███▏      | 10/32 [00:01<00:02,  7.63batch/s, loss=0.00787]

CRF Loss: 0.006256103515625
CRF Loss: 0.0074939727783203125


Epoch 8/20:  38%|███▊      | 12/32 [00:01<00:02,  7.55batch/s, loss=0.00771]

CRF Loss: 0.006389617919921875
CRF Loss: 0.005916595458984375


Epoch 8/20:  44%|████▍     | 14/32 [00:01<00:02,  7.62batch/s, loss=0.00756]

CRF Loss: 0.00740814208984375
CRF Loss: 0.007129669189453125


Epoch 8/20:  50%|█████     | 16/32 [00:01<00:02,  7.78batch/s, loss=0.00747]

CRF Loss: 0.0064792633056640625
CRF Loss: 0.008083343505859375


Epoch 8/20:  56%|█████▋    | 18/32 [00:02<00:01,  7.12batch/s, loss=0.00749]

CRF Loss: 0.0072956085205078125
CRF Loss: 0.00707244873046875


Epoch 8/20:  62%|██████▎   | 20/32 [00:02<00:01,  6.79batch/s, loss=0.00759]

CRF Loss: 0.009809494018554688
CRF Loss: 0.0056705474853515625


Epoch 8/20:  69%|██████▉   | 22/32 [00:02<00:01,  6.53batch/s, loss=0.00757]

CRF Loss: 0.00908660888671875
CRF Loss: 0.0062618255615234375


Epoch 8/20:  78%|███████▊  | 25/32 [00:03<00:00,  8.21batch/s, loss=0.00752]

CRF Loss: 0.009128570556640625
CRF Loss: 0.0060405731201171875


Epoch 8/20:  84%|████████▍ | 27/32 [00:03<00:00,  8.86batch/s, loss=0.00747]

CRF Loss: 0.00682830810546875
CRF Loss: 0.0068817138671875


Epoch 8/20:  88%|████████▊ | 28/32 [00:03<00:00,  9.10batch/s, loss=0.00742]

CRF Loss: 0.008844375610351562
CRF Loss: 0.0047931671142578125
CRF Loss: 0.004306793212890625


Epoch 8/20: 100%|██████████| 32/32 [00:03<00:00,  8.20batch/s, loss=0.00725]


CRF Loss: 0.0070209503173828125
CRF Loss: 0.00536346435546875
cuda
Epoch 8/20 - Total Loss: 0.2320


Epoch 9/20:   0%|          | 0/32 [00:00<?, ?batch/s]

CRF Loss: 0.0067462921142578125


Epoch 9/20:   6%|▋         | 2/32 [00:00<00:03,  9.50batch/s, loss=0.00644]

CRF Loss: 0.006130218505859375


Epoch 9/20:   9%|▉         | 3/32 [00:00<00:02,  9.71batch/s, loss=0.00657]

CRF Loss: 0.0068359375


Epoch 9/20:   9%|▉         | 3/32 [00:00<00:02,  9.71batch/s, loss=0.00647]

CRF Loss: 0.0061779022216796875
CRF Loss: 0.007869720458984375


Epoch 9/20:  22%|██▏       | 7/32 [00:00<00:02,  9.80batch/s, loss=0.00778]

CRF Loss: 0.006145477294921875
CRF Loss: 0.0067882537841796875
CRF Loss: 0.0044727325439453125


Epoch 9/20:  28%|██▊       | 9/32 [00:01<00:02,  9.92batch/s, loss=0.0064] 

CRF Loss: 0.00440216064453125
CRF Loss: 0.008440017700195312
CRF Loss: 0.00548553466796875


Epoch 9/20:  41%|████      | 13/32 [00:01<00:01,  9.90batch/s, loss=0.00694]

CRF Loss: 0.0070781707763671875
CRF Loss: 0.0066738128662109375


Epoch 9/20:  47%|████▋     | 15/32 [00:01<00:01,  9.91batch/s, loss=0.00684]

CRF Loss: 0.00516510009765625
CRF Loss: 0.007305145263671875
CRF Loss: 0.0042705535888671875


Epoch 9/20:  56%|█████▋    | 18/32 [00:01<00:01, 10.08batch/s, loss=0.00666]

CRF Loss: 0.005893707275390625
CRF Loss: 0.007404327392578125
CRF Loss: 0.0056705474853515625


Epoch 9/20:  62%|██████▎   | 20/32 [00:02<00:01, 10.17batch/s, loss=0.00625]

CRF Loss: 0.006771087646484375
CRF Loss: 0.0056133270263671875
CRF Loss: 0.00817108154296875


Epoch 9/20:  75%|███████▌  | 24/32 [00:02<00:00, 10.08batch/s, loss=0.0066]

CRF Loss: 0.0059719085693359375
CRF Loss: 0.0062618255615234375
CRF Loss: 0.007541656494140625


Epoch 9/20:  81%|████████▏ | 26/32 [00:02<00:00,  9.96batch/s, loss=0.00627]

CRF Loss: 0.0037384033203125
CRF Loss: 0.0062408447265625
CRF Loss: 0.005985260009765625


Epoch 9/20:  94%|█████████▍| 30/32 [00:03<00:00, 10.14batch/s, loss=0.00648]

CRF Loss: 0.0070858001708984375
CRF Loss: 0.00553131103515625
CRF Loss: 0.007663726806640625


Epoch 9/20: 100%|██████████| 32/32 [00:03<00:00, 10.03batch/s, loss=0.0066]


CRF Loss: 0.0089569091796875
cuda
Epoch 9/20 - Total Loss: 0.2045


Epoch 10/20:   3%|▎         | 1/32 [00:00<00:03,  9.02batch/s, loss=0.00596]

CRF Loss: 0.0059566497802734375
CRF Loss: 0.00567626953125


Epoch 10/20:   9%|▉         | 3/32 [00:00<00:03,  9.49batch/s, loss=0.00625]

CRF Loss: 0.0065402984619140625
CRF Loss: 0.0068302154541015625


Epoch 10/20:  16%|█▌        | 5/32 [00:00<00:02,  9.88batch/s, loss=0.00616]

CRF Loss: 0.0046291351318359375
CRF Loss: 0.0073337554931640625
CRF Loss: 0.0055389404296875


Epoch 10/20:  28%|██▊       | 9/32 [00:00<00:02, 10.12batch/s, loss=0.00657]

CRF Loss: 0.0041046142578125
CRF Loss: 0.0059814453125
CRF Loss: 0.0045261383056640625


Epoch 10/20:  34%|███▍      | 11/32 [00:01<00:02,  9.92batch/s, loss=0.00552]

CRF Loss: 0.0052547454833984375
CRF Loss: 0.003910064697265625


Epoch 10/20:  41%|████      | 13/32 [00:01<00:01,  9.78batch/s, loss=0.00558]

CRF Loss: 0.00566864013671875
CRF Loss: 0.00614166259765625


Epoch 10/20:  47%|████▋     | 15/32 [00:01<00:01,  9.93batch/s, loss=0.00556]

CRF Loss: 0.0060710906982421875
CRF Loss: 0.0048274993896484375
CRF Loss: 0.004856109619140625


Epoch 10/20:  59%|█████▉    | 19/32 [00:01<00:01, 10.07batch/s, loss=0.0057]

CRF Loss: 0.0046100616455078125
CRF Loss: 0.0041141510009765625
CRF Loss: 0.004840850830078125


Epoch 10/20:  66%|██████▌   | 21/32 [00:02<00:01,  9.91batch/s, loss=0.00535]

CRF Loss: 0.0053653717041015625
CRF Loss: 0.00484466552734375


Epoch 10/20:  72%|███████▏  | 23/32 [00:02<00:00,  9.85batch/s, loss=0.00542]

CRF Loss: 0.0040187835693359375
CRF Loss: 0.008350372314453125


Epoch 10/20:  78%|███████▊  | 25/32 [00:02<00:00, 10.01batch/s, loss=0.00544]

CRF Loss: 0.005924224853515625
CRF Loss: 0.00543212890625
CRF Loss: 0.0073986053466796875


Epoch 10/20:  91%|█████████ | 29/32 [00:02<00:00, 10.15batch/s, loss=0.00564]

CRF Loss: 0.004772186279296875
CRF Loss: 0.004398345947265625
CRF Loss: 0.0068187713623046875


Epoch 10/20: 100%|██████████| 32/32 [00:03<00:00, 10.01batch/s, loss=0.00547]


CRF Loss: 0.0046977996826171875
CRF Loss: 0.00563812255859375
cuda
Epoch 10/20 - Total Loss: 0.1751


Epoch 11/20:   3%|▎         | 1/32 [00:00<00:03,  8.78batch/s, loss=0.00518]

CRF Loss: 0.00550079345703125
CRF Loss: 0.0048580169677734375


Epoch 11/20:   9%|▉         | 3/32 [00:00<00:02,  9.69batch/s, loss=0.00495]

CRF Loss: 0.0058460235595703125
CRF Loss: 0.0035762786865234375
CRF Loss: 0.003620147705078125


Epoch 11/20:  22%|██▏       | 7/32 [00:00<00:02, 10.12batch/s, loss=0.00554]

CRF Loss: 0.0053195953369140625
CRF Loss: 0.0045108795166015625
CRF Loss: 0.005512237548828125


Epoch 11/20:  28%|██▊       | 9/32 [00:01<00:02, 10.14batch/s, loss=0.00468]

CRF Loss: 0.0040874481201171875
CRF Loss: 0.0039424896240234375
CRF Loss: 0.0070095062255859375


Epoch 11/20:  41%|████      | 13/32 [00:01<00:01, 10.05batch/s, loss=0.00524]

CRF Loss: 0.00495147705078125
CRF Loss: 0.0041294097900390625
CRF Loss: 0.004657745361328125


Epoch 11/20:  47%|████▋     | 15/32 [00:01<00:01, 10.14batch/s, loss=0.00482]

CRF Loss: 0.0033168792724609375
CRF Loss: 0.006214141845703125
CRF Loss: 0.006206512451171875


Epoch 11/20:  59%|█████▉    | 19/32 [00:01<00:01, 10.19batch/s, loss=0.00529]

CRF Loss: 0.0061702728271484375
CRF Loss: 0.0057811737060546875
CRF Loss: 0.0051670074462890625


Epoch 11/20:  66%|██████▌   | 21/32 [00:02<00:01,  9.79batch/s, loss=0.00499]

CRF Loss: 0.005001068115234375
CRF Loss: 0.0044841766357421875


Epoch 11/20:  72%|███████▏  | 23/32 [00:02<00:00,  9.92batch/s, loss=0.00492]

CRF Loss: 0.0040493011474609375
CRF Loss: 0.004131317138671875
CRF Loss: 0.004109382629394531


Epoch 11/20:  78%|███████▊  | 25/32 [00:02<00:00, 10.02batch/s, loss=0.00485]

CRF Loss: 0.00402069091796875
CRF Loss: 0.0059661865234375


Epoch 11/20:  88%|████████▊ | 28/32 [00:02<00:00,  8.58batch/s, loss=0.00487]

CRF Loss: 0.0042858123779296875
CRF Loss: 0.00489044189453125


Epoch 11/20:  94%|█████████▍| 30/32 [00:03<00:00,  7.73batch/s, loss=0.00486]

CRF Loss: 0.004608154296875
CRF Loss: 0.0046062469482421875


Epoch 11/20: 100%|██████████| 32/32 [00:03<00:00,  9.24batch/s, loss=0.00494]


CRF Loss: 0.0074920654296875
cuda
Epoch 11/20 - Total Loss: 0.1580


Epoch 12/20:   0%|          | 0/32 [00:00<?, ?batch/s]

CRF Loss: 0.006092071533203125


Epoch 12/20:   3%|▎         | 1/32 [00:00<00:04,  6.66batch/s, loss=0.00609]

CRF Loss: 0.005431175231933594


Epoch 12/20:   6%|▋         | 2/32 [00:00<00:04,  7.13batch/s, loss=0.00576]

CRF Loss: 0.004985809326171875


Epoch 12/20:   9%|▉         | 3/32 [00:00<00:03,  7.45batch/s, loss=0.0055]

CRF Loss: 0.006061553955078125


Epoch 12/20:  12%|█▎        | 4/32 [00:00<00:03,  7.52batch/s, loss=0.00564]

CRF Loss: 0.0040607452392578125


Epoch 12/20:  16%|█▌        | 5/32 [00:00<00:03,  6.79batch/s, loss=0.00533]

CRF Loss: 0.005817413330078125


Epoch 12/20:  19%|█▉        | 6/32 [00:00<00:03,  6.76batch/s, loss=0.00541]

CRF Loss: 0.0041408538818359375


Epoch 12/20:  22%|██▏       | 7/32 [00:01<00:03,  6.47batch/s, loss=0.00523]

CRF Loss: 0.004169464111328125


Epoch 12/20:  25%|██▌       | 8/32 [00:01<00:03,  6.27batch/s, loss=0.00509]

CRF Loss: 0.00428009033203125


Epoch 12/20:  28%|██▊       | 9/32 [00:01<00:03,  6.28batch/s, loss=0.005]

CRF Loss: 0.0056018829345703125


Epoch 12/20:  34%|███▍      | 11/32 [00:01<00:02,  7.55batch/s, loss=0.00502]

CRF Loss: 0.0045757293701171875


Epoch 12/20:  38%|███▊      | 12/32 [00:01<00:02,  7.71batch/s, loss=0.00487]

CRF Loss: 0.0032672882080078125


Epoch 12/20:  41%|████      | 13/32 [00:01<00:02,  8.27batch/s, loss=0.00489]

CRF Loss: 0.005061149597167969


Epoch 12/20:  44%|████▍     | 14/32 [00:01<00:02,  8.68batch/s, loss=0.00496]

CRF Loss: 0.005947113037109375
CRF Loss: 0.0059986114501953125


Epoch 12/20:  53%|█████▎    | 17/32 [00:02<00:01,  9.46batch/s, loss=0.00493]

CRF Loss: 0.0038089752197265625
CRF Loss: 0.004528045654296875
CRF Loss: 0.0054111480712890625


Epoch 12/20:  59%|█████▉    | 19/32 [00:02<00:01,  9.64batch/s, loss=0.00486]

CRF Loss: 0.0036916732788085938
CRF Loss: 0.004191398620605469
CRF Loss: 0.0036802291870117188


Epoch 12/20:  69%|██████▉   | 22/32 [00:02<00:01,  9.43batch/s, loss=0.00477]

CRF Loss: 0.004776954650878906
CRF Loss: 0.0041065216064453125
CRF Loss: 0.0055866241455078125


Epoch 12/20:  78%|███████▊  | 25/32 [00:03<00:00,  9.77batch/s, loss=0.00478]

CRF Loss: 0.0042858123779296875
CRF Loss: 0.004624366760253906
CRF Loss: 0.0035753250122070312


Epoch 12/20:  88%|████████▊ | 28/32 [00:03<00:00,  9.84batch/s, loss=0.00466]

CRF Loss: 0.0036077499389648438
CRF Loss: 0.00370025634765625
CRF Loss: 0.0035390853881835938


Epoch 12/20: 100%|██████████| 32/32 [00:03<00:00,  8.57batch/s, loss=0.00454]


CRF Loss: 0.003276824951171875
CRF Loss: 0.00347137451171875
cuda
Epoch 12/20 - Total Loss: 0.1454


Epoch 13/20:   6%|▋         | 2/32 [00:00<00:02, 10.21batch/s, loss=0.00701]

CRF Loss: 0.00334930419921875
CRF Loss: 0.0036563873291015625
CRF Loss: 0.004252433776855469


Epoch 13/20:  12%|█▎        | 4/32 [00:00<00:02, 10.16batch/s, loss=0.00417]

CRF Loss: 0.005471229553222656
CRF Loss: 0.0041027069091796875
CRF Loss: 0.004307746887207031


Epoch 13/20:  25%|██▌       | 8/32 [00:00<00:02, 10.14batch/s, loss=0.00474]

CRF Loss: 0.004431724548339844
CRF Loss: 0.0036420822143554688
CRF Loss: 0.0036163330078125


Epoch 13/20:  31%|███▏      | 10/32 [00:01<00:02,  9.73batch/s, loss=0.00408]

CRF Loss: 0.0038280487060546875
CRF Loss: 0.00417327880859375


Epoch 13/20:  38%|███▊      | 12/32 [00:01<00:02,  9.91batch/s, loss=0.004]  

CRF Loss: 0.0034914016723632812
CRF Loss: 0.003681182861328125
CRF Loss: 0.0047473907470703125


Epoch 13/20:  50%|█████     | 16/32 [00:01<00:01, 10.08batch/s, loss=0.00425]

CRF Loss: 0.0033359527587890625
CRF Loss: 0.0037107467651367188
CRF Loss: 0.003833770751953125


Epoch 13/20:  56%|█████▋    | 18/32 [00:01<00:01, 10.13batch/s, loss=0.00399]

CRF Loss: 0.0035333633422851562
CRF Loss: 0.004616737365722656


Epoch 13/20:  62%|██████▎   | 20/32 [00:02<00:01,  9.78batch/s, loss=0.00407]

CRF Loss: 0.004550933837890625
CRF Loss: 0.005070686340332031


Epoch 13/20:  69%|██████▉   | 22/32 [00:02<00:01,  9.90batch/s, loss=0.00425]

CRF Loss: 0.0061492919921875
CRF Loss: 0.006104469299316406
CRF Loss: 0.005515098571777344


Epoch 13/20:  81%|████████▏ | 26/32 [00:02<00:00, 10.07batch/s, loss=0.00438]

CRF Loss: 0.0032958984375
CRF Loss: 0.0031118392944335938
CRF Loss: 0.0038175582885742188


Epoch 13/20:  88%|████████▊ | 28/32 [00:02<00:00, 10.12batch/s, loss=0.00426]

CRF Loss: 0.004791259765625
CRF Loss: 0.005459785461425781


Epoch 13/20:  94%|█████████▍| 30/32 [00:03<00:00,  9.82batch/s, loss=0.0043] 

CRF Loss: 0.004187583923339844
CRF Loss: 0.0053310394287109375


Epoch 13/20: 100%|██████████| 32/32 [00:03<00:00, 10.02batch/s, loss=0.00445]


CRF Loss: 0.004638671875
cuda
Epoch 13/20 - Total Loss: 0.1378


Epoch 14/20:   3%|▎         | 1/32 [00:00<00:03,  9.95batch/s, loss=0.00418]

CRF Loss: 0.004177093505859375
CRF Loss: 0.004948616027832031


Epoch 14/20:   6%|▋         | 2/32 [00:00<00:03,  9.90batch/s, loss=0.00456]

CRF Loss: 0.004803657531738281


Epoch 14/20:   9%|▉         | 3/32 [00:00<00:02,  9.84batch/s, loss=0.00446]

CRF Loss: 0.00392913818359375


Epoch 14/20:  16%|█▌        | 5/32 [00:00<00:02,  9.86batch/s, loss=0.00544]

CRF Loss: 0.0038852691650390625


Epoch 14/20:  16%|█▌        | 5/32 [00:00<00:02,  9.86batch/s, loss=0.00429]

CRF Loss: 0.004000663757324219
CRF Loss: 0.00313568115234375


Epoch 14/20:  28%|██▊       | 9/32 [00:00<00:02,  9.22batch/s, loss=0.00409]

CRF Loss: 0.0039043426513671875
CRF Loss: 0.003993988037109375


Epoch 14/20:  31%|███▏      | 10/32 [00:01<00:02,  9.32batch/s, loss=0.00395]

CRF Loss: 0.0035638809204101562
CRF Loss: 0.0031423568725585938
CRF Loss: 0.0049896240234375


Epoch 14/20:  44%|████▍     | 14/32 [00:01<00:01,  9.89batch/s, loss=0.00436]

CRF Loss: 0.004521369934082031
CRF Loss: 0.0037450790405273438
CRF Loss: 0.0021953582763671875


Epoch 14/20:  53%|█████▎    | 17/32 [00:01<00:01,  9.74batch/s, loss=0.00415]

CRF Loss: 0.0033655166625976562
CRF Loss: 0.00415802001953125


Epoch 14/20:  56%|█████▋    | 18/32 [00:01<00:01,  9.57batch/s, loss=0.00386]

CRF Loss: 0.004584312438964844
CRF Loss: 0.0023708343505859375


Epoch 14/20:  66%|██████▌   | 21/32 [00:02<00:01,  9.82batch/s, loss=0.00385]

CRF Loss: 0.004169464111328125
CRF Loss: 0.003238677978515625
CRF Loss: 0.0039577484130859375


Epoch 14/20:  75%|███████▌  | 24/32 [00:02<00:00,  9.96batch/s, loss=0.00391]

CRF Loss: 0.004759788513183594
CRF Loss: 0.004267692565917969
CRF Loss: 0.0046291351318359375


Epoch 14/20:  84%|████████▍ | 27/32 [00:02<00:00,  9.69batch/s, loss=0.00396]

CRF Loss: 0.004520416259765625
CRF Loss: 0.004000663757324219


Epoch 14/20:  88%|████████▊ | 28/32 [00:03<00:00,  9.48batch/s, loss=0.004]  

CRF Loss: 0.004574775695800781
CRF Loss: 0.004399299621582031


Epoch 14/20:  94%|█████████▍| 30/32 [00:03<00:00,  9.72batch/s, loss=0.00403]

CRF Loss: 0.004603385925292969
CRF Loss: 0.004352569580078125
CRF Loss: 0.00473785400390625


Epoch 14/20: 100%|██████████| 32/32 [00:03<00:00,  9.73batch/s, loss=0.00418]


cuda
Epoch 14/20 - Total Loss: 0.1296


Epoch 15/20:   3%|▎         | 1/32 [00:00<00:03,  9.68batch/s, loss=0.00325]

CRF Loss: 0.0032548904418945312
CRF Loss: 0.002956390380859375


Epoch 15/20:  12%|█▎        | 4/32 [00:00<00:03,  9.16batch/s, loss=0.00334]

CRF Loss: 0.0035037994384765625
CRF Loss: 0.0036296844482421875


Epoch 15/20:  19%|█▉        | 6/32 [00:00<00:02,  9.35batch/s, loss=0.00336]

CRF Loss: 0.004302978515625
CRF Loss: 0.0025224685668945312


Epoch 15/20:  22%|██▏       | 7/32 [00:00<00:02,  9.50batch/s, loss=0.0035] 

CRF Loss: 0.0034437179565429688
CRF Loss: 0.00435638427734375
CRF Loss: 0.0036573410034179688


Epoch 15/20:  34%|███▍      | 11/32 [00:01<00:02,  9.95batch/s, loss=0.00403]

CRF Loss: 0.004021644592285156
CRF Loss: 0.0046215057373046875
CRF Loss: 0.002727508544921875


Epoch 15/20:  41%|████      | 13/32 [00:01<00:02,  8.33batch/s, loss=0.00353]

CRF Loss: 0.0029449462890625
CRF Loss: 0.0033197402954101562


Epoch 15/20:  47%|████▋     | 15/32 [00:01<00:02,  7.34batch/s, loss=0.00351]

CRF Loss: 0.0034580230712890625
CRF Loss: 0.004355430603027344


Epoch 15/20:  53%|█████▎    | 17/32 [00:02<00:02,  7.32batch/s, loss=0.00357]

CRF Loss: 0.0035338401794433594
CRF Loss: 0.004990577697753906


Epoch 15/20:  59%|█████▉    | 19/32 [00:02<00:01,  7.51batch/s, loss=0.00367]

CRF Loss: 0.004130840301513672
CRF Loss: 0.0034513473510742188


Epoch 15/20:  66%|██████▌   | 21/32 [00:02<00:01,  7.62batch/s, loss=0.00365]

CRF Loss: 0.0034170150756835938
CRF Loss: 0.0049991607666015625


Epoch 15/20:  72%|███████▏  | 23/32 [00:02<00:01,  7.06batch/s, loss=0.00374]

CRF Loss: 0.004367351531982422
CRF Loss: 0.0036792755126953125


Epoch 15/20:  78%|███████▊  | 25/32 [00:03<00:01,  6.55batch/s, loss=0.00374]

CRF Loss: 0.0037555694580078125
CRF Loss: 0.0035085678100585938


Epoch 15/20:  84%|████████▍ | 27/32 [00:03<00:00,  6.91batch/s, loss=0.00377]

CRF Loss: 0.005053043365478516
CRF Loss: 0.00372314453125


Epoch 15/20:  94%|█████████▍| 30/32 [00:03<00:00,  8.35batch/s, loss=0.00381]

CRF Loss: 0.0038785934448242188
CRF Loss: 0.0047206878662109375


Epoch 15/20: 100%|██████████| 32/32 [00:03<00:00,  8.10batch/s, loss=0.00383]


CRF Loss: 0.004031181335449219
CRF Loss: 0.0041046142578125
cuda
Epoch 15/20 - Total Loss: 0.1224


Epoch 16/20:   0%|          | 0/32 [00:00<?, ?batch/s]

CRF Loss: 0.004134178161621094


Epoch 16/20:   6%|▋         | 2/32 [00:00<00:03,  9.80batch/s, loss=0.00297]

CRF Loss: 0.0017986297607421875


Epoch 16/20:   9%|▉         | 3/32 [00:00<00:02,  9.76batch/s, loss=0.00321]

CRF Loss: 0.0036950111389160156


Epoch 16/20:  12%|█▎        | 4/32 [00:00<00:02,  9.83batch/s, loss=0.00343]

CRF Loss: 0.00409698486328125


Epoch 16/20:  16%|█▌        | 5/32 [00:00<00:02,  9.32batch/s, loss=0.00356]

CRF Loss: 0.004062652587890625


Epoch 16/20:  19%|█▉        | 6/32 [00:00<00:02,  9.50batch/s, loss=0.0035]

CRF Loss: 0.0032396316528320312


Epoch 16/20:  22%|██▏       | 7/32 [00:00<00:02,  9.61batch/s, loss=0.00347]

CRF Loss: 0.0032405853271484375
CRF Loss: 0.0027189254760742188


Epoch 16/20:  31%|███▏      | 10/32 [00:01<00:02,  9.78batch/s, loss=0.00377]

CRF Loss: 0.0034303665161132812
CRF Loss: 0.0035085678100585938
CRF Loss: 0.003937244415283203


Epoch 16/20:  38%|███▊      | 12/32 [00:01<00:02,  9.91batch/s, loss=0.0035] 

CRF Loss: 0.0032935142517089844
CRF Loss: 0.00428009033203125
CRF Loss: 0.0032396316528320312


Epoch 16/20:  47%|████▋     | 15/32 [00:01<00:01,  9.75batch/s, loss=0.00343]

CRF Loss: 0.0029888153076171875
CRF Loss: 0.003169536590576172
CRF Loss: 0.0023899078369140625


Epoch 16/20:  56%|█████▋    | 18/32 [00:01<00:01,  9.70batch/s, loss=0.00339]

CRF Loss: 0.0032663345336914062
CRF Loss: 0.0039424896240234375


Epoch 16/20:  62%|██████▎   | 20/32 [00:02<00:01,  9.86batch/s, loss=0.00341]

CRF Loss: 0.004503726959228516
CRF Loss: 0.0026209354400634766
CRF Loss: 0.004607200622558594


Epoch 16/20:  75%|███████▌  | 24/32 [00:02<00:00,  9.79batch/s, loss=0.00347]

CRF Loss: 0.0034046173095703125
CRF Loss: 0.003745555877685547


Epoch 16/20:  78%|███████▊  | 25/32 [00:02<00:00,  9.69batch/s, loss=0.00348]

CRF Loss: 0.0038957595825195312
CRF Loss: 0.0033903121948242188
CRF Loss: 0.006262779235839844


Epoch 16/20:  88%|████████▊ | 28/32 [00:02<00:00,  9.77batch/s, loss=0.00354]

CRF Loss: 0.003155231475830078
CRF Loss: 0.002748250961303711


Epoch 16/20:  94%|█████████▍| 30/32 [00:03<00:00,  9.87batch/s, loss=0.00352]

CRF Loss: 0.0035376548767089844
CRF Loss: 0.0028228759765625
CRF Loss: 0.0068264007568359375


Epoch 16/20: 100%|██████████| 32/32 [00:03<00:00,  9.83batch/s, loss=0.00374]


cuda
Epoch 16/20 - Total Loss: 0.1160


Epoch 17/20:   3%|▎         | 1/32 [00:00<00:03,  9.71batch/s, loss=0.00395]

CRF Loss: 0.003947257995605469
CRF Loss: 0.0030357837677001953


Epoch 17/20:  12%|█▎        | 4/32 [00:00<00:02,  9.89batch/s, loss=0.00437]

CRF Loss: 0.003658294677734375
CRF Loss: 0.0024628639221191406
CRF Loss: 0.0028629302978515625


Epoch 17/20:  19%|█▉        | 6/32 [00:00<00:02,  9.62batch/s, loss=0.00314]

CRF Loss: 0.003032207489013672
CRF Loss: 0.002956867218017578


Epoch 17/20:  25%|██▌       | 8/32 [00:00<00:02,  9.84batch/s, loss=0.00343]

CRF Loss: 0.004789829254150391
CRF Loss: 0.004080295562744141
CRF Loss: 0.0025331974029541016


Epoch 17/20:  38%|███▊      | 12/32 [00:01<00:02,  9.89batch/s, loss=0.00372]

CRF Loss: 0.0034689903259277344
CRF Loss: 0.004089832305908203


Epoch 17/20:  41%|████      | 13/32 [00:01<00:01,  9.90batch/s, loss=0.00336]

CRF Loss: 0.0036911964416503906
CRF Loss: 0.0024704933166503906
CRF Loss: 0.0030570030212402344


Epoch 17/20:  50%|█████     | 16/32 [00:01<00:01,  9.66batch/s, loss=0.00335]

CRF Loss: 0.0029941797256469727
CRF Loss: 0.0037488937377929688


Epoch 17/20:  56%|█████▋    | 18/32 [00:01<00:01,  9.88batch/s, loss=0.00331]

CRF Loss: 0.002629518508911133
CRF Loss: 0.003336668014526367
CRF Loss: 0.003394603729248047


Epoch 17/20:  69%|██████▉   | 22/32 [00:02<00:01,  9.87batch/s, loss=0.00349]

CRF Loss: 0.0038068294525146484
CRF Loss: 0.0032575130462646484


Epoch 17/20:  72%|███████▏  | 23/32 [00:02<00:00,  9.79batch/s, loss=0.00332]

CRF Loss: 0.0028362274169921875
CRF Loss: 0.0035409927368164062
CRF Loss: 0.0028023719787597656


Epoch 17/20:  81%|████████▏ | 26/32 [00:02<00:00,  9.78batch/s, loss=0.00332]

CRF Loss: 0.00408172607421875
CRF Loss: 0.003105640411376953


Epoch 17/20:  91%|█████████ | 29/32 [00:02<00:00,  9.90batch/s, loss=0.00335]

CRF Loss: 0.002825498580932617
CRF Loss: 0.004627704620361328


Epoch 17/20:  94%|█████████▍| 30/32 [00:03<00:00,  9.84batch/s, loss=0.00331]

CRF Loss: 0.003113985061645508
CRF Loss: 0.002492189407348633


Epoch 17/20: 100%|██████████| 32/32 [00:03<00:00,  9.82batch/s, loss=0.00342]


CRF Loss: 0.0034084320068359375
cuda
Epoch 17/20 - Total Loss: 0.1061


Epoch 18/20:   3%|▎         | 1/32 [00:00<00:03,  9.42batch/s, loss=0.00311]

CRF Loss: 0.003114461898803711


Epoch 18/20:   6%|▋         | 2/32 [00:00<00:03,  9.40batch/s, loss=0.0031]

CRF Loss: 0.003093719482421875


Epoch 18/20:   6%|▋         | 2/32 [00:00<00:03,  9.40batch/s, loss=0.00291]

CRF Loss: 0.002508878707885742
CRF Loss: 0.0027289390563964844


Epoch 18/20:  19%|█▉        | 6/32 [00:00<00:02,  9.87batch/s, loss=0.00337]

CRF Loss: 0.0018880367279052734
CRF Loss: 0.003536224365234375
CRF Loss: 0.002091318368911743


Epoch 18/20:  28%|██▊       | 9/32 [00:00<00:02,  9.76batch/s, loss=0.00283]

CRF Loss: 0.0026841163635253906
CRF Loss: 0.0038616955280303955


Epoch 18/20:  31%|███▏      | 10/32 [00:01<00:02,  9.50batch/s, loss=0.00288]

CRF Loss: 0.002566888928413391
CRF Loss: 0.0036373138427734375
CRF Loss: 0.0035152360796928406


Epoch 18/20:  44%|████▍     | 14/32 [00:01<00:01,  9.63batch/s, loss=0.00282]

CRF Loss: 0.0017584562301635742
CRF Loss: 0.002530515193939209


Epoch 18/20:  47%|████▋     | 15/32 [00:01<00:01,  9.59batch/s, loss=0.00281]

CRF Loss: 0.003022909164428711
CRF Loss: 0.002369523048400879
CRF Loss: 0.0031228065490722656


Epoch 18/20:  59%|█████▉    | 19/32 [00:01<00:01,  9.82batch/s, loss=0.00293]

CRF Loss: 0.002689763903617859
CRF Loss: 0.0019527152180671692


Epoch 18/20:  62%|██████▎   | 20/32 [00:02<00:01,  9.78batch/s, loss=0.00283]

CRF Loss: 0.003218531608581543
CRF Loss: 0.003473997116088867
CRF Loss: 0.002811908721923828


Epoch 18/20:  75%|███████▌  | 24/32 [00:02<00:00,  9.67batch/s, loss=0.00284]

CRF Loss: 0.002815723419189453
CRF Loss: 0.003092065453529358


Epoch 18/20:  78%|███████▊  | 25/32 [00:02<00:00,  9.69batch/s, loss=0.00292]

CRF Loss: 0.003516584634780884
CRF Loss: 0.004437759518623352
CRF Loss: 0.0035734176635742188


Epoch 18/20:  84%|████████▍ | 27/32 [00:02<00:00,  9.82batch/s, loss=0.00296]

CRF Loss: 0.0033665671944618225
CRF Loss: 0.0032096803188323975


Epoch 18/20:  94%|█████████▍| 30/32 [00:03<00:00,  8.52batch/s, loss=0.003]

CRF Loss: 0.003765612840652466
CRF Loss: 0.003588557243347168


Epoch 18/20: 100%|██████████| 32/32 [00:03<00:00,  9.33batch/s, loss=0.00302]


CRF Loss: 0.003211498260498047
cuda
Epoch 18/20 - Total Loss: 0.0968


Epoch 19/20:   0%|          | 0/32 [00:00<?, ?batch/s]

CRF Loss: 0.003108687698841095


Epoch 19/20:   3%|▎         | 1/32 [00:00<00:04,  6.76batch/s, loss=0.00311]

CRF Loss: 0.0030022263526916504


Epoch 19/20:   6%|▋         | 2/32 [00:00<00:04,  7.09batch/s, loss=0.00306]

CRF Loss: 0.0032756924629211426


Epoch 19/20:   9%|▉         | 3/32 [00:00<00:04,  6.90batch/s, loss=0.00313]

CRF Loss: 0.0029463768005371094


Epoch 19/20:  12%|█▎        | 4/32 [00:00<00:04,  6.67batch/s, loss=0.00308]

CRF Loss: 0.002733886241912842


Epoch 19/20:  16%|█▌        | 5/32 [00:00<00:03,  6.88batch/s, loss=0.00301]

CRF Loss: 0.002826377749443054


Epoch 19/20:  19%|█▉        | 6/32 [00:00<00:03,  6.88batch/s, loss=0.00298]

CRF Loss: 0.003286898136138916


Epoch 19/20:  22%|██▏       | 7/32 [00:01<00:03,  6.63batch/s, loss=0.00303]

CRF Loss: 0.002715587615966797


Epoch 19/20:  25%|██▌       | 8/32 [00:01<00:03,  6.45batch/s, loss=0.00299]

CRF Loss: 0.002189159393310547


Epoch 19/20:  28%|██▊       | 9/32 [00:01<00:03,  6.31batch/s, loss=0.0029]

CRF Loss: 0.002077817916870117


Epoch 19/20:  31%|███▏      | 10/32 [00:01<00:03,  6.07batch/s, loss=0.00282]

CRF Loss: 0.0033528506755828857


Epoch 19/20:  34%|███▍      | 11/32 [00:01<00:03,  6.21batch/s, loss=0.00287]

CRF Loss: 0.0029439404606819153


Epoch 19/20:  38%|███▊      | 12/32 [00:01<00:02,  6.73batch/s, loss=0.00293]

CRF Loss: 0.003602445125579834


Epoch 19/20:  44%|████▍     | 14/32 [00:02<00:02,  7.94batch/s, loss=0.00316]

CRF Loss: 0.0030423402786254883
CRF Loss: 0.0026034116744995117


Epoch 19/20:  53%|█████▎    | 17/32 [00:02<00:01,  8.88batch/s, loss=0.00292]

CRF Loss: 0.0032693445682525635
CRF Loss: 0.002621769905090332
CRF Loss: 0.0031577348709106445


Epoch 19/20:  59%|█████▉    | 19/32 [00:02<00:01,  9.11batch/s, loss=0.00296]

CRF Loss: 0.003191351890563965
CRF Loss: 0.003276720643043518


Epoch 19/20:  66%|██████▌   | 21/32 [00:02<00:01,  9.53batch/s, loss=0.00293]

CRF Loss: 0.0029734671115875244
CRF Loss: 0.0021935701370239258
CRF Loss: 0.0020166486501693726


Epoch 19/20:  78%|███████▊  | 25/32 [00:03<00:00,  9.79batch/s, loss=0.00292]

CRF Loss: 0.0030974149703979492
CRF Loss: 0.003507383167743683


Epoch 19/20:  84%|████████▍ | 27/32 [00:03<00:00,  9.59batch/s, loss=0.0029]

CRF Loss: 0.0026335716247558594
CRF Loss: 0.0025899112224578857
CRF Loss: 0.0029942989349365234


Epoch 19/20:  91%|█████████ | 29/32 [00:03<00:00,  9.50batch/s, loss=0.00289]

CRF Loss: 0.00292002409696579
CRF Loss: 0.0024214237928390503
CRF Loss: 0.0026128292083740234


Epoch 19/20: 100%|██████████| 32/32 [00:03<00:00,  8.31batch/s, loss=0.00282]


CRF Loss: 0.0012106895446777344
cuda
Epoch 19/20 - Total Loss: 0.0904


Epoch 20/20:   3%|▎         | 1/32 [00:00<00:03,  9.82batch/s, loss=0.00249]

CRF Loss: 0.002494625747203827
CRF Loss: 0.0025854408740997314


Epoch 20/20:  12%|█▎        | 4/32 [00:00<00:02,  9.72batch/s, loss=0.0034]

CRF Loss: 0.0022118762135505676
CRF Loss: 0.002900034189224243


Epoch 20/20:  16%|█▌        | 5/32 [00:00<00:02,  9.66batch/s, loss=0.00278]

CRF Loss: 0.0028453171253204346
CRF Loss: 0.003672376275062561


Epoch 20/20:  22%|██▏       | 7/32 [00:00<00:02,  9.62batch/s, loss=0.00286]

CRF Loss: 0.0020583271980285645
CRF Loss: 0.004081249237060547
CRF Loss: 0.0019321292638778687


Epoch 20/20:  34%|███▍      | 11/32 [00:01<00:02,  9.89batch/s, loss=0.00295]

CRF Loss: 0.002161383628845215
CRF Loss: 0.002600952982902527
CRF Loss: 0.002294987440109253


Epoch 20/20:  44%|████▍     | 14/32 [00:01<00:01,  9.49batch/s, loss=0.00263]

CRF Loss: 0.0018001869320869446
CRF Loss: 0.0031125545501708984


Epoch 20/20:  50%|█████     | 16/32 [00:01<00:01,  9.38batch/s, loss=0.00257]

CRF Loss: 0.002124980092048645
CRF Loss: 0.0022035837173461914


Epoch 20/20:  53%|█████▎    | 17/32 [00:01<00:01,  9.44batch/s, loss=0.00256]

CRF Loss: 0.002219080924987793
CRF Loss: 0.002847641706466675
CRF Loss: 0.0022974982857704163


Epoch 20/20:  66%|██████▌   | 21/32 [00:02<00:01,  9.88batch/s, loss=0.00269]

CRF Loss: 0.0020489096641540527
CRF Loss: 0.0033375322818756104
CRF Loss: 0.002369314432144165


Epoch 20/20:  75%|███████▌  | 24/32 [00:02<00:00,  9.76batch/s, loss=0.00272]

CRF Loss: 0.003198377788066864
CRF Loss: 0.0032598674297332764


Epoch 20/20:  81%|████████▏ | 26/32 [00:02<00:00,  9.58batch/s, loss=0.00262]

CRF Loss: 0.0024468302726745605
CRF Loss: 0.0029060840606689453


Epoch 20/20:  88%|████████▊ | 28/32 [00:02<00:00,  9.52batch/s, loss=0.0026]

CRF Loss: 0.002528965473175049
CRF Loss: 0.002125948667526245


Epoch 20/20:  91%|█████████ | 29/32 [00:03<00:00,  9.49batch/s, loss=0.00261]

CRF Loss: 0.0028299614787101746
CRF Loss: 0.00291287899017334
CRF Loss: 0.0023041069507598877


Epoch 20/20: 100%|██████████| 32/32 [00:03<00:00,  9.74batch/s, loss=0.00258]

CRF Loss: 0.0018901824951171875
cuda
Epoch 20/20 - Total Loss: 0.0826
Predictions for 'John is from New York': ['O', 'O', 'O', 'O', 'O']





In [14]:
from sklearn.metrics import classification_report
import torch

def load_data(file_path):
    sentences, tags = [], []
    with open(file_path, "r") as f:
        sentence, tag1 = [], []
        for line in f.readlines():
            if len(line.strip()) == 0:
                if sentence:
                    sentences.append(sentence)
                    tags.append(tag1)
                sentence, tag1 = [], []
            else:
                parts = line.strip().split("\t")
                if len(parts) == 2:
                    sent, tag = parts
                else:
                    sent = parts[0]
                    tag = 'O'

                sentence.append(sent)
                tag1.append(tag)
    return sentences, tags

def save_predictions(model, dataset_path, output_file, word2idx, tag2idx, device, batch_size=8):
    sentences, tags = load_data(dataset_path)
    X_data = [[word2idx.get(word, word2idx['<PAD>']) for word in sentence] for sentence in sentences]
    max_len_current = max(len(sentence) for sentence in X_data)
    X_data = [sentence + [0] * (max_len_current - len(sentence)) for sentence in X_data]

    model.eval()
    all_predictions = []

    with torch.no_grad():
        for i in range(0, len(X_data), batch_size):
            batch_X = X_data[i:min(i+batch_size, len(X_data))]
            X_tensor = torch.tensor(batch_X, dtype=torch.long).to(device)
            batch_predictions = model.predict(X_tensor)
            all_predictions.extend(batch_predictions)
            del X_tensor
            torch.cuda.empty_cache()
    idx2tag = {idx: tag for tag, idx in tag2idx.items()}
    predicted_tags = [[idx2tag[tag] for tag in pred[:len(sentences[i])]] for i, pred in enumerate(all_predictions)]

    with open(output_file, "w") as f_out:
        for sent, pred_tags in zip(sentences, predicted_tags):
            for word, tag in zip(sent, pred_tags):
                f_out.write(f"{word}\t{tag}\n")
            f_out.write("\n")

    return sentences, tags, predicted_tags

def evaluate_model(true_labels, pred_labels):
    """ Compute classification metrics: Precision, Recall, F1-score """
    true_labels_flat = [tag for sent in true_labels for tag in sent]
    pred_labels_flat = [tag for sent in pred_labels for tag in sent]

    print("\nClassification Report:")
    print(classification_report(true_labels_flat, pred_labels_flat, zero_division=0))

batch_size = 16

_, true_tags_dev, pred_tags_dev = save_predictions(
    model,
    "/content/dev",
    "dev.output_3",
    word2idx,
    tag2idx,
    device,
    batch_size
)

print("Processing test set...")
_, true_tags_test, pred_tags_test = save_predictions(
    model,
    "/content/test",
    "test.output_3",
    word2idx,
    tag2idx,
    device,
    batch_size
)

print("Evaluating dev set...")
evaluate_model(true_tags_dev, pred_tags_dev)

print("Evaluating test set...")
evaluate_model(true_tags_test, pred_tags_test)

Processing dev set...
Processing test set...
Evaluating dev set...

Classification Report:
              precision    recall  f1-score   support

       B-DNA       0.00      0.00      0.00         0
       B-RNA       0.00      0.00      0.00         0
 B-cell_line       0.00      0.00      0.00         0
 B-cell_type       0.00      0.00      0.00         0
   B-protein       0.00      0.00      0.00         0
       I-DNA       0.00      0.00      0.00         0
       I-RNA       0.00      0.00      0.00         0
 I-cell_line       0.00      0.00      0.00         0
 I-cell_type       0.00      0.00      0.00         0
   I-protein       0.00      0.00      0.00         0
           O       1.00      0.89      0.94      2800

    accuracy                           0.89      2800
   macro avg       0.09      0.08      0.09      2800
weighted avg       1.00      0.89      0.94      2800

Evaluating test set...

Classification Report:
              precision    recall  f1-score   sup