Format lại code

In [None]:
import matplotlib.pyplot as plt
import os
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, TensorDataset
import torch.nn.functional as F
import torchvision.models as models
from torchvision import transforms
import math
import torch.multiprocessing as mp
from transformers import BertTokenizerFast
import pickle
# Evaluate
from nltk.translate.bleu_score import sentence_bleu
# from nltk.translate.meteor_score import meteor_score
# from rogue import Rogue
# from pycocoevalcap.cider.cider import Cider
# from pycocoevalcap.spice.spice import Spice
#end
# mp.set_start_method('spawn', force=True)
# Shared BERT WordPiece tokenizer; its vocabulary size is also used as the decoder's output size.
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
detokenize = tokenizer.convert_ids_to_tokens
batch_detokenize = tokenizer.batch_decode
BATCH_SIZE = 64
# Author's measurement for batch size 64: about 3.1 GB VRAM with "b16"
# (presumably bfloat16) -- 3.0 GB dedicated | 0.1 GB shared.
# Fall back to the CPU so the notebook still runs on machines without a CUDA device.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
image_transforms = transforms.Compose([
    transforms.Resize((299, 299)),  # Inception v3 works on 299x299 inputs
    transforms.ToTensor(),
    # ImageNet channel statistics. torchvision's pretrained Inception v3 turns on
    # `transform_input`, which assumes ImageNet-normalised input and converts it
    # internally to the 0.5/0.5 scaling of the original weights. Normalising with
    # 0.5/0.5 here would therefore rescale the images twice.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


In [None]:
def load_image(image_folder_path: str):
    """Load and preprocess every image file in a folder into memory.

    The author's original note explains why everything is loaded eagerly:
    DataLoader multiprocessing did not work for them inside Jupyter.

    Args:
        image_folder_path: Directory whose regular files are all opened as images.

    Returns:
        dict mapping the file name without extension to the tensor produced by
        the module-level ``image_transforms``. When two files share a stem, the
        first one listed wins.
    """
    result = {}
    # List the directory once so the progress total and the loop see the same entries.
    file_names = os.listdir(image_folder_path)
    file_counts = len(file_names)
    last_log = 0
    count = 0
    for file_name in file_names:
        file_path = os.path.join(image_folder_path, file_name)
        if not os.path.isfile(file_path):
            continue
        image_id = os.path.splitext(file_name)[0]
        if image_id in result:
            continue
        # Manage the *opened* image so its file handle is released; `convert`
        # returns an independent copy that the transform pipeline consumes.
        with Image.open(file_path) as raw_image:
            result[image_id] = image_transforms(raw_image.convert("RGB"))
        count += 1
        progress = count / file_counts * 100
        # Log roughly every 10 percentage points.
        if progress - last_log >= 10:
            last_log = progress
            print(f"Image loading {progress:.2f} % ")
    if last_log != 100:
        print(f"Image loading 100 % ")
    return result
def process_data(image_dict: dict[str, torch.Tensor], processed_data_path: str):
    """Pair every pickled caption with its preloaded image tensor.

    Args:
        image_dict: Mapping from image id to image tensor.
        processed_data_path: Pickle file holding an iterable of
            ``(image_id, caption)`` pairs; each caption must be convertible
            by ``torch.tensor``.

    Returns:
        list of ``(image_tensor, caption_tensor)`` tuples in file order.

    Raises:
        KeyError: If a pickled image id is not present in ``image_dict``.
    """
    # NOTE: pickle.load can execute arbitrary code -- only point this at trusted files.
    with open(processed_data_path, 'rb') as handle:
        records = pickle.load(handle)
    pairs: list[tuple[torch.Tensor, torch.Tensor]] = []
    for image_id, token_ids in records:
        caption_tensor = torch.tensor(token_ids)
        pairs.append((image_dict[image_id], caption_tensor))
    return pairs
def get_train_test_loader(
    batch_size: int,
    n_wokers: int = 2,
    image_path: str = "../Flickr8k/Flicker8k_Dataset",
    train_path: str = "../Data_bert/train_set_bert.pkl",
    test_path: str = "../Data_bert/test_set_bert.pkl",
):
    """Build the train and test DataLoaders over in-memory (image, caption) pairs.

    Args:
        batch_size: Batch size used by both loaders.
        n_wokers: Number of DataLoader worker processes (the spelling is kept
            because callers may pass it by keyword).
        image_path: Folder containing the image files.
        train_path: Pickle file with the training ``(image_id, caption)`` pairs.
        test_path: Pickle file with the test ``(image_id, caption)`` pairs.

    Returns:
        ``(trainloader, testloader)``; only the training loader shuffles.
    """
    # Images are loaded once and shared between both splits.
    image_dict: dict[str, torch.Tensor] = load_image(image_path)
    train_data = process_data(image_dict, train_path)
    test_data = process_data(image_dict, test_path)
    trainloader = DataLoader(train_data, batch_size=batch_size, num_workers=n_wokers, shuffle=True)
    testloader = DataLoader(test_data, batch_size=batch_size, num_workers=n_wokers, shuffle=False)
    return trainloader, testloader
# Build both loaders once; every experiment cell below reuses them.
trainloader, testloader = get_train_test_loader(batch_size=BATCH_SIZE, n_wokers=2)

Image loading 10.01 % 
Image loading 20.02 % 
Image loading 30.03 % 
Image loading 40.04 % 
Image loading 50.06 % 
Image loading 60.07 % 
Image loading 70.08 % 
Image loading 80.09 % 
Image loading 90.10 % 
Image loading 100 % 


In [None]:
# for images, captions in trainloader:
#     for i in range(min(4, captions.shape[0])):
#         print(images[i][:,100,100])
#         print(detokenize(captions[i]))
#     break

In [None]:
class EncoderCNN(nn.Module):
    """Image encoder: pretrained Inception v3 followed by a small projection head.

    The backbone keeps its own classification layer, so it emits class logits
    that are passed through ReLU, dropout and a linear layer to obtain a
    feature vector of ``output_size``.

    Args:
        output_size: Dimension of the image feature vector handed to the decoder.
    """

    def __init__(self, output_size: int):
        super().__init__()
        self.inception_model = models.inception_v3(pretrained=True)
        # Alternative left by the author: replace the classifier with
        # torch.nn.Identity() to expose the pooled features instead of logits.
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.2)
        # Size the projection from the backbone's classifier instead of a magic 1000.
        self.fc = nn.Linear(self.inception_model.fc.out_features, output_size)
        # Fine-tune only the backbone's final classifier. Exact-name matching is
        # required: a substring test on "fc.weight" would also match
        # "AuxLogits.fc.weight" and unfreeze the auxiliary head, whose output
        # is discarded in forward().
        trainable_names = {"fc.weight", "fc.bias"}
        for name, param in self.inception_model.named_parameters():
            param.requires_grad = name in trainable_names

    def forward(self, images: torch.Tensor):
        """Encode a batch of images into ``[bsz, output_size]`` feature vectors."""
        features = self.inception_model(images)  # [bsz, n_classes] logits
        # The backbone may return a tuple (e.g. main plus auxiliary logits in
        # training mode); only the main logits are used.
        if isinstance(features, tuple):
            features = features[0]
        features = self.relu(features)
        features = self.dropout(features)
        features = self.fc(features)
        return features
class DecoderRNN(nn.Module):
    """LSTM caption decoder that sees the image features at every time step.

    Each token embedding is concatenated with the image feature vector before
    it enters the LSTM; the LSTM output goes through ReLU, dropout and a
    linear layer to produce vocabulary logits.

    Attributes:
        use_hidden: When False (the default) any ``hidden_state`` passed to
            ``forward`` is ignored and the LSTM starts from its default state.
    """

    def __init__(self, embed_size, vocab_size, hidden_size, input_size, num_layers):
        super().__init__()
        self.vocab_size = vocab_size
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(
            input_size=embed_size + input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.use_hidden = False

    def forward(self, features: torch.Tensor, captions: torch.Tensor, hidden_state: tuple[torch.Tensor, torch.Tensor] = None):
        """Run the decoder over ``captions``.

        Args:
            features: Image features, ``[bsz, img_sz]``.
            captions: Token ids, ``[bsz, seq]`` (callers in this notebook use seq = 1).
            hidden_state: Optional ``(h, c)`` pair, each ``[num_layers, bsz, hidden]``;
                only consulted when ``self.use_hidden`` is True.

        Returns:
            ``(logits, hidden_state)`` where logits is ``[bsz, seq, vocab_size]``
            and hidden_state is the LSTM's final ``(h, c)``.
        """
        token_vectors = self.embed(captions)  # [bsz, seq, embed]
        steps = token_vectors.shape[1]
        # Repeat the image vector along the sequence axis: [bsz, seq, img_sz]
        tiled_features = features.unsqueeze(1).expand(-1, steps, -1)
        lstm_input = torch.cat([tiled_features, token_vectors], dim=2)  # [bsz, seq, img_sz + embed]
        # A None initial state makes nn.LSTM fall back to its default state.
        initial_state = hidden_state if self.use_hidden else None
        lstm_out, next_state = self.lstm(lstm_input, initial_state)  # lstm_out: [bsz, seq, hidden]
        logits = self.linear(self.dropout(self.relu(lstm_out)))  # [bsz, seq, vocab_size]
        return logits, next_state
class ImageToTextModel(nn.Module):
    """Encoder-decoder captioning model that decodes one token per step.

    Both ``forward`` and ``predict`` return logits of shape
    ``[bsz, length, vocab]`` whose position 0 is an all-zero placeholder,
    because the first token is given rather than predicted.

    Args:
        encoder: Module mapping images to ``[bsz, img_sz]`` features.
        decoder: Module called as ``decoder(features, tokens, hidden_state)``
            and exposing ``vocab_size`` and ``use_hidden``.
    """

    def __init__(self, encoder: nn.Module, decoder: nn.Module):
        super().__init__()
        self.encoder: EncoderCNN = encoder
        self.decoder: DecoderRNN = decoder

    def _start_logits(self, bsz: int, target_device: torch.device) -> torch.Tensor:
        """Return the all-zero ``[bsz, 1, vocab]`` placeholder for position 0.

        The tensor is created on ``target_device`` (the device of the encoder
        features) rather than on a module-level global, so the model works
        wherever its inputs live.
        """
        return torch.zeros(
            (bsz, 1, self.decoder.vocab_size),
            dtype=torch.float32,
            device=target_device,
        )

    def forward(self, images: torch.Tensor, captions: torch.Tensor):
        """Teacher-forced pass: step ``di`` is fed the ground-truth token ``di - 1``.

        Args:
            images: ``[bsz, 3, height, width]`` image batch.
            captions: ``[bsz, seq_length]`` ground-truth token ids.

        Returns:
            Logits of shape ``[bsz, seq_length, vocab]``.
        """
        features = self.encoder(images)  # [bsz, img_sz]
        seq_predicted = [self._start_logits(images.shape[0], features.device)]
        hidden_state = None
        for di in range(1, captions.shape[1]):
            decoder_input = captions[:, di - 1].unsqueeze(1)  # [bsz, 1]
            if self.decoder.use_hidden:
                output_decoder, hidden_state = self.decoder(features, decoder_input, hidden_state)
            else:
                output_decoder, hidden_state = self.decoder(features, decoder_input)
            seq_predicted.append(output_decoder)  # [bsz, 1, vocab]
        return torch.cat(seq_predicted, dim=1)

    def predict(self, images: torch.Tensor, captions: torch.Tensor, predict_length: int):
        """Greedy decoding: each step is fed the argmax token of the previous step.

        Args:
            images: ``[bsz, 3, height, width]`` image batch.
            captions: ``[bsz, 1]`` start-token ids used as the first decoder input.
            predict_length: Total output length, including placeholder position 0.

        Returns:
            Logits of shape ``[bsz, predict_length, vocab]``.
        """
        features = self.encoder(images)
        seq_predicted = [self._start_logits(images.shape[0], features.device)]
        hidden_state = None
        decoder_input = captions
        for _ in range(predict_length - 1):
            output_decoder, hidden_state = self.decoder(features, decoder_input, hidden_state)
            seq_predicted.append(output_decoder)
            decoder_input = output_decoder.argmax(2)  # [bsz, 1]
        return torch.cat(seq_predicted, dim=1)


In [None]:
def train(model : "ImageToTextModel", dataloader : DataLoader, lossf : callable, optimizer : torch.optim.Optimizer, mixed: bool, device: str):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: Called as ``model(images, captions)``; must return logits of
            shape ``[bsz, seq, vocab]``.
        dataloader: Iterable of ``(images, captions)`` batches.
        lossf: Loss taking ``([N, vocab] logits, [N] targets)``.
        optimizer: Optimizer stepped once per batch.
        mixed: If True, run the forward pass and loss under bfloat16 autocast.
        device: Device the batches are moved to, e.g. ``"cuda"``, ``"cuda:0"`` or ``"cpu"``.

    Returns:
        Mean loss over the batches, or ``0.0`` when the loader yields no batch.
    """
    model.train()
    # torch.autocast only accepts a bare device type, so drop any index ("cuda:0" -> "cuda").
    device_type = torch.device(device).type
    epoch_loss = 0.0
    batch_count = 0
    for images, captions in dataloader:
        optimizer.zero_grad()
        images = images.to(device)
        captions = captions.to(device)
        # enabled=False turns the context into a no-op, so a single code path
        # serves both the mixed and the full-precision case.
        with torch.autocast(device_type=device_type, dtype=torch.bfloat16, enabled=mixed):
            outputs = model(images, captions)
            loss = lossf(outputs.reshape(-1, outputs.shape[2]), captions.reshape(-1))
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        batch_count += 1
    if batch_count == 0:
        # An empty loader has no loss to average; avoid dividing by zero.
        return 0.0
    return epoch_loss / batch_count
def _strip_trailing_pad(tokens: list) -> list:
    """Return ``tokens`` without its trailing "[PAD]" entries (safe for all-pad input)."""
    end = len(tokens)
    while end > 0 and tokens[end - 1] == "[PAD]":
        end -= 1
    return tokens[:end]

def compute_metrics(predict: torch.Tensor, caption: torch.Tensor):
    """Score one predicted token-id sequence against its reference caption.

    Args:
        predict: 1-D sequence of predicted token ids.
        caption: 1-D sequence of reference token ids.

    Returns:
        dict with ``"accuracy"`` (1 if the token at index 1 matches, else 0) and
        ``"bleu_1"`` .. ``"bleu_4"`` sentence-level BLEU scores.
    """
    # TODO (author's note): the leading token at index 0 is not removed before
    # scoring, so it still takes part in the BLEU n-gram counts.
    # Accuracy only checks the first generated position (index 1).
    is_correct = 1 if predict[1] == caption[1] else 0
    # Trailing padding is dropped without indexing into an empty list, so a
    # prediction made only of [PAD] tokens no longer raises IndexError.
    reference_tokens = _strip_trailing_pad(detokenize(caption))
    predicted_tokens = _strip_trailing_pad(detokenize(predict))
    # BLEU with uniform weights over the first k n-gram orders.
    reference = [reference_tokens]
    bleu_1 = sentence_bleu(reference, predicted_tokens, weights=(1, 0, 0, 0))
    bleu_2 = sentence_bleu(reference, predicted_tokens, weights=(0.5, 0.5, 0, 0))
    # Exact thirds, so the weights sum to 1 (0.33 * 3 fell short of it).
    bleu_3 = sentence_bleu(reference, predicted_tokens, weights=(1 / 3, 1 / 3, 1 / 3, 0))
    bleu_4 = sentence_bleu(reference, predicted_tokens)
    # TODO: ROUGE, METEOR, CIDEr and SPICE were sketched by the author but are
    # disabled; their optional dependencies are commented out in the import cell.
    return {
        "accuracy" : is_correct,
        "bleu_1" : bleu_1,
        "bleu_2" : bleu_2,
        "bleu_3" : bleu_3,
        "bleu_4" : bleu_4,
    }
def test(model : ImageToTextModel, dataloader : DataLoader, lossf : callable, mixed: bool, device: str):
    model.eval()
    count = 0
    total_count = 0
    epoch_metrics = {'loss' : 0}
    input = torch.tensor([101]).to(device)
    for images, captions in dataloader:
        images: torch.Tensor = images.to(device)
        captions: torch.Tensor = captions.to(device)
        bsz = images.shape[0]
        inputs = input.unsqueeze(1).expand((bsz, 1))
        if mixed:
            with torch.autocast(device_type=device, dtype=torch.bfloat16):
                outputs: torch.Tensor = model.predict(images, inputs, captions.shape[1])
                loss: torch.Tensor= lossf(outputs.view(-1, outputs.shape[2]), captions.view(-1))
        else:
            outputs: torch.Tensor  = model.predict(images, inputs, captions.shape[1])
            loss: torch.Tensor= lossf(outputs.view(-1, outputs.shape[2]), captions.view(-1))
        predicts = outputs.argmax(2)
        count += 1
        total_count += bsz
        epoch_metrics['loss'] += loss.item()
        for i in range(bsz):
            metrics = compute_metrics(predicts[i], captions[i])
            for key in metrics:
                if key not in epoch_metrics:
                    epoch_metrics[key] = metrics[key]
                else:
                    epoch_metrics[key] += metrics[key]
    for key in epoch_metrics:
        if key in ['loss']:
            epoch_metrics[key] /= count
        else:
            epoch_metrics[key] /= total_count
    return epoch_metrics

In [None]:
# Experiment 1: decoder WITHOUT hidden-state carry-over (use_hidden = False).
# `util` is a project-local module that is not shown in this notebook.
import util
# Sizes shared with the second experiment cell further down.
image_size = 256   # dimension of the encoder's output feature vector
embed_size = 256   # token embedding dimension
hidden_size = 256  # LSTM hidden dimension
encoder = EncoderCNN(
    output_size=image_size
)
decoder = DecoderRNN(
    embed_size=embed_size,
    vocab_size=tokenizer.vocab_size,
    hidden_size=hidden_size,
    input_size=image_size,
    num_layers=1
)
# With use_hidden False the decoder ignores any hidden state it is given, so
# every decoding step runs the LSTM from its default initial state.
decoder.use_hidden = False
image_to_text_model = ImageToTextModel(
    encoder=encoder,
    decoder=decoder
)
# Targets equal to 0 are excluded from the loss; presumably 0 is the [PAD] id -- confirm.
loss_func= nn.CrossEntropyLoss(ignore_index=0)
# NOTE(review): the meaning of the scheduling arguments (lr, gamma, warmup_*)
# is defined inside util.train_eval -- check there before tuning them.
# NOTE(review): the recorded output below shows a single iteration (1/1), so it
# appears to come from a run with a different num_epochs than the 50 set here.
# NOTE(review): save_path is identical to the one in the use_hidden=True cell;
# confirm that util.train_eval does not overwrite the earlier checkpoint.
model_path, train_log = util.train_eval(
    trainloader=trainloader,
    testloader=testloader,
    model=image_to_text_model,
    train_func=train,
    test_func=test,
    lossf=loss_func,
    num_epochs=50,
    lr=2.5e-3,
    gamma=0.95,
    log_step=1,
    warmup_nepochs=10,
    warmup_lr=1e-3,
    warmup_gamma=1.1,
    save=True,
    mixed_train = True,
    mixed_eval = False,
    save_path="checkpoint/teset_model",
    optimizer_type=torch.optim.Adam,
    device=device,
    metadata_extra={
        "batch_size" : BATCH_SIZE,
        "dataset_name" : "Flickr8k",
        "use_hidden" : decoder.use_hidden
    },
    log_metric=True
)



Start train


100%|██████████| 1/1 [00:06<00:00,  6.30s/it]

Train loss : 10.3357 | Test loss : 10.3010 | Train time : 3.26 s | Lr : 0.00100000
{'loss': 10.301046371459961, 'accuracy': 0.0, 'bleu_1': 0.0, 'bleu_2': 0.0, 'bleu_3': 0.0, 'bleu_4': 0.0}





Complete


In [None]:
# Experiment 2: same setup as the previous experiment cell, but the decoder
# carries its LSTM state across decoding steps (use_hidden = True).
# Relies on `util`, image_size, embed_size and hidden_size defined in that cell.
encoder = EncoderCNN(
    output_size=image_size
)
decoder = DecoderRNN(
    embed_size=embed_size,
    vocab_size=tokenizer.vocab_size,
    hidden_size=hidden_size,
    input_size=image_size,
    num_layers=1
)
# With use_hidden True the hidden state returned by one step is fed to the next.
decoder.use_hidden = True
image_to_text_model = ImageToTextModel(
    encoder=encoder,
    decoder=decoder
)
# Targets equal to 0 are excluded from the loss; presumably 0 is the [PAD] id -- confirm.
loss_func= nn.CrossEntropyLoss(ignore_index=0)
# NOTE(review): save_path is identical to the one in the use_hidden=False cell;
# confirm that util.train_eval does not overwrite the earlier checkpoint.
model_path, train_log = util.train_eval(
    trainloader=trainloader,
    testloader=testloader,
    model=image_to_text_model,
    train_func=train,
    test_func=test,
    lossf=loss_func,
    num_epochs=50,
    lr=2.5e-3,
    gamma=0.95,
    log_step=1,
    warmup_nepochs=10,
    warmup_lr=1e-3,
    warmup_gamma=1.1,
    save=True,
    mixed_train = True,
    mixed_eval = False,
    save_path="checkpoint/teset_model",
    optimizer_type=torch.optim.Adam,
    device=device,
    metadata_extra={
        "batch_size" : BATCH_SIZE,
        "dataset_name" : "Flickr8k",
        "use_hidden" : decoder.use_hidden
    },
    log_metric=True
)

Start train


  2%|▏         | 1/50 [00:06<05:02,  6.17s/it]

Train loss : 10.3144 | Test loss : 10.2732 | Train time : 3.30 s | Lr : 0.00100000
{'loss': 10.273191452026367, 'accuracy': 0.3, 'bleu_1': 0.019565217391304346, 'bleu_2': 1.4428752331563675e-155, 'bleu_3': 1.4966750252888628e-204, 'bleu_4': 3.98358021083051e-232}


  4%|▍         | 2/50 [00:12<04:51,  6.07s/it]

Train loss : 10.1874 | Test loss : 10.1764 | Train time : 3.03 s | Lr : 0.00110000
{'loss': 10.176424980163574, 'accuracy': 0.7, 'bleu_1': 0.02608695652173913, 'bleu_2': 1.973844012052805e-155, 'bleu_3': 2.0643129098604857e-204, 'bleu_4': 5.5150404173765435e-232}


  6%|▌         | 3/50 [00:18<04:43,  6.03s/it]

Train loss : 10.0034 | Test loss : 10.0100 | Train time : 3.02 s | Lr : 0.00121000
{'loss': 10.010049819946289, 'accuracy': 0.7, 'bleu_1': 0.02608695652173913, 'bleu_2': 1.973844012052805e-155, 'bleu_3': 2.0643129098604857e-204, 'bleu_4': 5.5150404173765435e-232}


  8%|▊         | 4/50 [00:24<04:41,  6.11s/it]

Train loss : 9.6733 | Test loss : 9.7731 | Train time : 3.12 s | Lr : 0.00133100
{'loss': 9.773061752319336, 'accuracy': 0.7, 'bleu_1': 0.02608695652173913, 'bleu_2': 1.973844012052805e-155, 'bleu_3': 2.0643129098604857e-204, 'bleu_4': 5.5150404173765435e-232}


 10%|█         | 5/50 [00:30<04:32,  6.05s/it]

Train loss : 9.2667 | Test loss : 9.5097 | Train time : 3.07 s | Lr : 0.00146410
{'loss': 9.509742736816406, 'accuracy': 0.7, 'bleu_1': 0.02608695652173913, 'bleu_2': 1.973844012052805e-155, 'bleu_3': 2.0643129098604857e-204, 'bleu_4': 5.5150404173765435e-232}


 12%|█▏        | 6/50 [00:36<04:22,  5.97s/it]

Train loss : 8.8059 | Test loss : 9.1889 | Train time : 2.90 s | Lr : 0.00161051
{'loss': 9.188897132873535, 'accuracy': 0.7, 'bleu_1': 0.02608695652173913, 'bleu_2': 1.973844012052805e-155, 'bleu_3': 2.0643129098604857e-204, 'bleu_4': 5.5150404173765435e-232}


 14%|█▍        | 7/50 [00:41<04:12,  5.87s/it]

Train loss : 8.3497 | Test loss : 8.8821 | Train time : 2.87 s | Lr : 0.00177156
{'loss': 8.88210678100586, 'accuracy': 0.7, 'bleu_1': 0.02608695652173913, 'bleu_2': 1.973844012052805e-155, 'bleu_3': 2.0643129098604857e-204, 'bleu_4': 5.5150404173765435e-232}


 16%|█▌        | 8/50 [00:47<04:07,  5.90s/it]

Train loss : 7.8274 | Test loss : 8.5995 | Train time : 3.03 s | Lr : 0.00194872
{'loss': 8.599486351013184, 'accuracy': 0.7, 'bleu_1': 0.02608695652173913, 'bleu_2': 1.973844012052805e-155, 'bleu_3': 2.0643129098604857e-204, 'bleu_4': 5.5150404173765435e-232}


 18%|█▊        | 9/50 [00:53<04:02,  5.92s/it]

Train loss : 7.3260 | Test loss : 8.2627 | Train time : 3.03 s | Lr : 0.00214359
{'loss': 8.262710571289062, 'accuracy': 0.7, 'bleu_1': 0.02608695652173913, 'bleu_2': 1.973844012052805e-155, 'bleu_3': 2.0643129098604857e-204, 'bleu_4': 5.5150404173765435e-232}


 20%|██        | 10/50 [00:59<03:57,  5.95s/it]

Train loss : 6.7026 | Test loss : 7.9441 | Train time : 2.95 s | Lr : 0.00235795
{'loss': 7.944087028503418, 'accuracy': 0.7, 'bleu_1': 0.02608695652173913, 'bleu_2': 1.973844012052805e-155, 'bleu_3': 2.0643129098604857e-204, 'bleu_4': 5.5150404173765435e-232}


 22%|██▏       | 11/50 [01:05<03:51,  5.95s/it]

Train loss : 6.1240 | Test loss : 7.6882 | Train time : 3.11 s | Lr : 0.00259374
{'loss': 7.688172817230225, 'accuracy': 0.7, 'bleu_1': 0.02608695652173913, 'bleu_2': 1.973844012052805e-155, 'bleu_3': 2.0643129098604857e-204, 'bleu_4': 5.5150404173765435e-232}


 24%|██▍       | 12/50 [01:11<03:44,  5.92s/it]

Train loss : 5.5626 | Test loss : 7.5512 | Train time : 2.95 s | Lr : 0.00285312
{'loss': 7.551210880279541, 'accuracy': 0.7, 'bleu_1': 0.02608695652173913, 'bleu_2': 1.973844012052805e-155, 'bleu_3': 2.0643129098604857e-204, 'bleu_4': 5.5150404173765435e-232}


 26%|██▌       | 13/50 [01:17<03:39,  5.94s/it]

Train loss : 5.0281 | Test loss : 7.5918 | Train time : 3.10 s | Lr : 0.00313843
{'loss': 7.591808319091797, 'accuracy': 0.7, 'bleu_1': 0.02608695652173913, 'bleu_2': 1.973844012052805e-155, 'bleu_3': 2.0643129098604857e-204, 'bleu_4': 5.5150404173765435e-232}


 28%|██▊       | 14/50 [01:23<03:37,  6.03s/it]

Train loss : 4.6959 | Test loss : 7.8316 | Train time : 3.04 s | Lr : 0.00345227
{'loss': 7.831576824188232, 'accuracy': 0.7, 'bleu_1': 0.02608695652173913, 'bleu_2': 1.973844012052805e-155, 'bleu_3': 2.0643129098604857e-204, 'bleu_4': 5.5150404173765435e-232}


 30%|███       | 15/50 [01:29<03:32,  6.06s/it]

Train loss : 4.4431 | Test loss : 8.1356 | Train time : 3.03 s | Lr : 0.00379750
{'loss': 8.135635375976562, 'accuracy': 0.7, 'bleu_1': 0.02608695652173913, 'bleu_2': 1.973844012052805e-155, 'bleu_3': 2.0643129098604857e-204, 'bleu_4': 5.5150404173765435e-232}


 32%|███▏      | 16/50 [01:35<03:26,  6.06s/it]

Train loss : 4.3365 | Test loss : 8.5048 | Train time : 3.02 s | Lr : 0.00400000
{'loss': 8.50479507446289, 'accuracy': 0.7, 'bleu_1': 0.02608695652173913, 'bleu_2': 1.973844012052805e-155, 'bleu_3': 2.0643129098604857e-204, 'bleu_4': 5.5150404173765435e-232}


 34%|███▍      | 17/50 [01:42<03:21,  6.09s/it]

Train loss : 4.2177 | Test loss : 8.7724 | Train time : 3.19 s | Lr : 0.00360000
{'loss': 8.772432327270508, 'accuracy': 0.7, 'bleu_1': 0.02608695652173913, 'bleu_2': 1.973844012052805e-155, 'bleu_3': 2.0643129098604857e-204, 'bleu_4': 5.5150404173765435e-232}


 36%|███▌      | 18/50 [01:48<03:16,  6.13s/it]

Train loss : 4.2735 | Test loss : 9.0522 | Train time : 3.10 s | Lr : 0.00324000
{'loss': 9.052202224731445, 'accuracy': 0.0, 'bleu_1': 0.021739130434782605, 'bleu_2': 2.199344694155892e-155, 'bleu_3': 2.514995661874996e-204, 'bleu_4': 6.995501686664744e-232}


 38%|███▊      | 19/50 [01:54<03:08,  6.09s/it]

Train loss : 4.2747 | Test loss : 9.2094 | Train time : 3.09 s | Lr : 0.00291600
{'loss': 9.209364891052246, 'accuracy': 0.0, 'bleu_1': 0.021739130434782605, 'bleu_2': 2.199344694155892e-155, 'bleu_3': 2.514995661874996e-204, 'bleu_4': 6.995501686664744e-232}


 40%|████      | 20/50 [02:00<03:04,  6.14s/it]

Train loss : 4.1936 | Test loss : 9.3552 | Train time : 3.23 s | Lr : 0.00262440
{'loss': 9.355242729187012, 'accuracy': 0.0, 'bleu_1': 0.021739130434782605, 'bleu_2': 2.199344694155892e-155, 'bleu_3': 2.514995661874996e-204, 'bleu_4': 6.995501686664744e-232}


 42%|████▏     | 21/50 [02:06<02:55,  6.06s/it]

Train loss : 4.2050 | Test loss : 9.5152 | Train time : 2.97 s | Lr : 0.00236196
{'loss': 9.515209197998047, 'accuracy': 0.0, 'bleu_1': 0.021739130434782605, 'bleu_2': 2.199344694155892e-155, 'bleu_3': 2.514995661874996e-204, 'bleu_4': 6.995501686664744e-232}


 44%|████▍     | 22/50 [02:12<02:49,  6.04s/it]

Train loss : 4.2907 | Test loss : 9.6038 | Train time : 3.05 s | Lr : 0.00212576
{'loss': 9.603788375854492, 'accuracy': 0.0, 'bleu_1': 0.021739130434782605, 'bleu_2': 2.199344694155892e-155, 'bleu_3': 2.514995661874996e-204, 'bleu_4': 6.995501686664744e-232}


 46%|████▌     | 23/50 [02:18<02:42,  6.02s/it]

Train loss : 4.2340 | Test loss : 9.7008 | Train time : 3.06 s | Lr : 0.00191319
{'loss': 9.70083999633789, 'accuracy': 0.0, 'bleu_1': 0.021739130434782605, 'bleu_2': 2.199344694155892e-155, 'bleu_3': 2.514995661874996e-204, 'bleu_4': 6.995501686664744e-232}


 48%|████▊     | 24/50 [02:24<02:36,  6.02s/it]

Train loss : 4.2024 | Test loss : 9.7346 | Train time : 2.93 s | Lr : 0.00172187
{'loss': 9.734627723693848, 'accuracy': 0.0, 'bleu_1': 0.021739130434782605, 'bleu_2': 2.199344694155892e-155, 'bleu_3': 2.514995661874996e-204, 'bleu_4': 6.995501686664744e-232}


 50%|█████     | 25/50 [02:30<02:30,  6.02s/it]

Train loss : 4.2106 | Test loss : 9.8102 | Train time : 3.13 s | Lr : 0.00154968
{'loss': 9.810232162475586, 'accuracy': 0.7, 'bleu_1': 0.03695652173913043, 'bleu_2': 2.8370435746127058e-155, 'bleu_3': 2.9674669272513624e-204, 'bleu_4': 7.922020771156347e-232}


 52%|█████▏    | 26/50 [02:36<02:23,  5.96s/it]

Train loss : 4.1803 | Test loss : 9.8972 | Train time : 3.02 s | Lr : 0.00139471
{'loss': 9.897170066833496, 'accuracy': 0.7, 'bleu_1': 0.04782608695652173, 'bleu_2': 3.1755883016109695e-155, 'bleu_3': 3.1845015165594013e-204, 'bleu_4': 8.345673348941911e-232}


 54%|█████▍    | 27/50 [02:42<02:16,  5.95s/it]

Train loss : 4.1607 | Test loss : 9.9441 | Train time : 3.00 s | Lr : 0.00125524
{'loss': 9.944064140319824, 'accuracy': 0.7, 'bleu_1': 0.02608695652173913, 'bleu_2': 1.973844012052805e-155, 'bleu_3': 2.0643129098604857e-204, 'bleu_4': 5.5150404173765435e-232}


 56%|█████▌    | 28/50 [02:48<02:10,  5.91s/it]

Train loss : 4.2309 | Test loss : 9.9747 | Train time : 2.94 s | Lr : 0.00112972
{'loss': 9.974706649780273, 'accuracy': 0.7, 'bleu_1': 0.02608695652173913, 'bleu_2': 1.973844012052805e-155, 'bleu_3': 2.0643129098604857e-204, 'bleu_4': 5.5150404173765435e-232}


 58%|█████▊    | 29/50 [02:53<02:03,  5.90s/it]

Train loss : 4.1821 | Test loss : 9.9916 | Train time : 2.92 s | Lr : 0.00101675
{'loss': 9.991578102111816, 'accuracy': 0.7, 'bleu_1': 0.02608695652173913, 'bleu_2': 1.973844012052805e-155, 'bleu_3': 2.0643129098604857e-204, 'bleu_4': 5.5150404173765435e-232}


 60%|██████    | 30/50 [03:00<01:59,  5.97s/it]

Train loss : 4.1876 | Test loss : 9.9873 | Train time : 3.04 s | Lr : 0.00091507
{'loss': 9.987344741821289, 'accuracy': 0.7, 'bleu_1': 0.02608695652173913, 'bleu_2': 1.973844012052805e-155, 'bleu_3': 2.0643129098604857e-204, 'bleu_4': 5.5150404173765435e-232}


 62%|██████▏   | 31/50 [03:06<01:53,  5.98s/it]

Train loss : 4.2264 | Test loss : 9.9742 | Train time : 3.06 s | Lr : 0.00082356
{'loss': 9.974248886108398, 'accuracy': 0.7, 'bleu_1': 0.02608695652173913, 'bleu_2': 1.973844012052805e-155, 'bleu_3': 2.0643129098604857e-204, 'bleu_4': 5.5150404173765435e-232}


 64%|██████▍   | 32/50 [03:12<01:47,  5.99s/it]

Train loss : 4.2001 | Test loss : 9.9770 | Train time : 3.05 s | Lr : 0.00074121
{'loss': 9.977022171020508, 'accuracy': 0.7, 'bleu_1': 0.02608695652173913, 'bleu_2': 1.973844012052805e-155, 'bleu_3': 2.0643129098604857e-204, 'bleu_4': 5.5150404173765435e-232}


 66%|██████▌   | 33/50 [03:18<01:41,  6.00s/it]

Train loss : 4.1882 | Test loss : 9.9868 | Train time : 3.06 s | Lr : 0.00066709
{'loss': 9.986806869506836, 'accuracy': 0.7, 'bleu_1': 0.02608695652173913, 'bleu_2': 1.973844012052805e-155, 'bleu_3': 2.0643129098604857e-204, 'bleu_4': 5.5150404173765435e-232}


 68%|██████▊   | 34/50 [03:24<01:36,  6.01s/it]

Train loss : 4.1701 | Test loss : 9.9968 | Train time : 3.08 s | Lr : 0.00060038
{'loss': 9.996774673461914, 'accuracy': 0.7, 'bleu_1': 0.02608695652173913, 'bleu_2': 1.973844012052805e-155, 'bleu_3': 2.0643129098604857e-204, 'bleu_4': 5.5150404173765435e-232}


 70%|███████   | 35/50 [03:30<01:31,  6.10s/it]

Train loss : 4.2252 | Test loss : 10.0056 | Train time : 3.23 s | Lr : 0.00054034
{'loss': 10.005646705627441, 'accuracy': 0.7, 'bleu_1': 0.02608695652173913, 'bleu_2': 1.973844012052805e-155, 'bleu_3': 2.0643129098604857e-204, 'bleu_4': 5.5150404173765435e-232}


 72%|███████▏  | 36/50 [03:36<01:25,  6.08s/it]

Train loss : 4.2640 | Test loss : 10.0134 | Train time : 3.04 s | Lr : 0.00048631
{'loss': 10.013350486755371, 'accuracy': 0.7, 'bleu_1': 0.02608695652173913, 'bleu_2': 1.973844012052805e-155, 'bleu_3': 2.0643129098604857e-204, 'bleu_4': 5.5150404173765435e-232}


 74%|███████▍  | 37/50 [03:42<01:18,  6.05s/it]

Train loss : 4.1567 | Test loss : 10.0195 | Train time : 3.01 s | Lr : 0.00043768
{'loss': 10.019546508789062, 'accuracy': 0.7, 'bleu_1': 0.02608695652173913, 'bleu_2': 1.973844012052805e-155, 'bleu_3': 2.0643129098604857e-204, 'bleu_4': 5.5150404173765435e-232}


 76%|███████▌  | 38/50 [03:48<01:12,  6.04s/it]

Train loss : 4.2176 | Test loss : 10.0253 | Train time : 3.09 s | Lr : 0.00039391
{'loss': 10.025252342224121, 'accuracy': 0.7, 'bleu_1': 0.02608695652173913, 'bleu_2': 1.973844012052805e-155, 'bleu_3': 2.0643129098604857e-204, 'bleu_4': 5.5150404173765435e-232}


 78%|███████▊  | 39/50 [03:54<01:06,  6.05s/it]

Train loss : 4.2167 | Test loss : 10.0317 | Train time : 3.14 s | Lr : 0.00035452
{'loss': 10.031721115112305, 'accuracy': 0.7, 'bleu_1': 0.02608695652173913, 'bleu_2': 1.973844012052805e-155, 'bleu_3': 2.0643129098604857e-204, 'bleu_4': 5.5150404173765435e-232}


 80%|████████  | 40/50 [04:00<01:00,  6.00s/it]

Train loss : 4.2174 | Test loss : 10.0372 | Train time : 2.99 s | Lr : 0.00031907
{'loss': 10.03724479675293, 'accuracy': 0.7, 'bleu_1': 0.02608695652173913, 'bleu_2': 1.973844012052805e-155, 'bleu_3': 2.0643129098604857e-204, 'bleu_4': 5.5150404173765435e-232}


 82%|████████▏ | 41/50 [04:06<00:54,  6.01s/it]

Train loss : 4.1973 | Test loss : 10.0444 | Train time : 3.07 s | Lr : 0.00028716
{'loss': 10.044437408447266, 'accuracy': 0.7, 'bleu_1': 0.02608695652173913, 'bleu_2': 1.973844012052805e-155, 'bleu_3': 2.0643129098604857e-204, 'bleu_4': 5.5150404173765435e-232}


 84%|████████▍ | 42/50 [04:12<00:48,  6.02s/it]

Train loss : 4.2064 | Test loss : 10.0514 | Train time : 3.07 s | Lr : 0.00025844
{'loss': 10.05142879486084, 'accuracy': 0.7, 'bleu_1': 0.02608695652173913, 'bleu_2': 1.973844012052805e-155, 'bleu_3': 2.0643129098604857e-204, 'bleu_4': 5.5150404173765435e-232}


 86%|████████▌ | 43/50 [04:18<00:42,  6.03s/it]

Train loss : 4.1750 | Test loss : 10.0564 | Train time : 3.07 s | Lr : 0.00023260
{'loss': 10.056373596191406, 'accuracy': 0.7, 'bleu_1': 0.02608695652173913, 'bleu_2': 1.973844012052805e-155, 'bleu_3': 2.0643129098604857e-204, 'bleu_4': 5.5150404173765435e-232}


 88%|████████▊ | 44/50 [04:24<00:36,  6.03s/it]

Train loss : 4.2377 | Test loss : 10.0604 | Train time : 3.14 s | Lr : 0.00020934
{'loss': 10.060368537902832, 'accuracy': 0.7, 'bleu_1': 0.02608695652173913, 'bleu_2': 1.973844012052805e-155, 'bleu_3': 2.0643129098604857e-204, 'bleu_4': 5.5150404173765435e-232}


 90%|█████████ | 45/50 [04:30<00:30,  6.01s/it]

Train loss : 4.1969 | Test loss : 10.0619 | Train time : 2.98 s | Lr : 0.00018841
{'loss': 10.06187629699707, 'accuracy': 0.7, 'bleu_1': 0.02608695652173913, 'bleu_2': 1.973844012052805e-155, 'bleu_3': 2.0643129098604857e-204, 'bleu_4': 5.5150404173765435e-232}


 92%|█████████▏| 46/50 [04:36<00:23,  5.99s/it]

Train loss : 4.2390 | Test loss : 10.0624 | Train time : 3.01 s | Lr : 0.00016956
{'loss': 10.062381744384766, 'accuracy': 0.7, 'bleu_1': 0.02608695652173913, 'bleu_2': 1.973844012052805e-155, 'bleu_3': 2.0643129098604857e-204, 'bleu_4': 5.5150404173765435e-232}


 94%|█████████▍| 47/50 [04:42<00:17,  5.99s/it]

Train loss : 4.2027 | Test loss : 10.0618 | Train time : 3.06 s | Lr : 0.00015261
{'loss': 10.061808586120605, 'accuracy': 0.7, 'bleu_1': 0.02608695652173913, 'bleu_2': 1.973844012052805e-155, 'bleu_3': 2.0643129098604857e-204, 'bleu_4': 5.5150404173765435e-232}


 96%|█████████▌| 48/50 [04:48<00:11,  5.99s/it]

Train loss : 4.2162 | Test loss : 10.0611 | Train time : 3.05 s | Lr : 0.00013735
{'loss': 10.061114311218262, 'accuracy': 0.7, 'bleu_1': 0.02608695652173913, 'bleu_2': 1.973844012052805e-155, 'bleu_3': 2.0643129098604857e-204, 'bleu_4': 5.5150404173765435e-232}


 98%|█████████▊| 49/50 [04:54<00:06,  6.05s/it]

Train loss : 4.1819 | Test loss : 10.0608 | Train time : 3.28 s | Lr : 0.00012361
{'loss': 10.060831069946289, 'accuracy': 0.7, 'bleu_1': 0.02608695652173913, 'bleu_2': 1.973844012052805e-155, 'bleu_3': 2.0643129098604857e-204, 'bleu_4': 5.5150404173765435e-232}


100%|██████████| 50/50 [05:00<00:00,  6.01s/it]

Train loss : 4.1630 | Test loss : 10.0601 | Train time : 3.04 s | Lr : 0.00011125
{'loss': 10.060079574584961, 'accuracy': 0.7, 'bleu_1': 0.02608695652173913, 'bleu_2': 1.973844012052805e-155, 'bleu_3': 2.0643129098604857e-204, 'bleu_4': 5.5150404173765435e-232}





Complete
