Format lại code

In [1]:
import matplotlib.pyplot as plt
import os
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, TensorDataset
import torch.nn.functional as F
import torchvision.models as models
from torchvision import transforms
import math
import torch.multiprocessing as mp
from transformers import BertTokenizerFast
import pickle
# Evaluate
from nltk.translate.bleu_score import sentence_bleu
# from nltk.translate.meteor_score import meteor_score
# from rogue import Rogue
# from pycocoevalcap.cider.cider import Cider
# from pycocoevalcap.spice.spice import Spice
#end
# mp.set_start_method('spawn', force=True)
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
detokenize = tokenizer.convert_ids_to_tokens
batch_detokenize = tokenizer.batch_decode
BATCH_SIZE = 64 
#64 : 3.1GB VRAM b16 (3.0 Dedicated | 0.1 Shared)
device = 'cuda'
image_transforms = transforms.Compose([
    transforms.Resize((299, 299)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Dùng cho ImageNet
])


In [2]:
def load_image(image_folder_path: str): # DataLoader trên Jupyternotebook ko xử lý đa luồng đc nên đành phải load hết vô
    result = {}
    file_counts = len(os.listdir(image_folder_path))
    progress = 0
    last_log = 0
    count = 0
    for file_name in os.listdir(image_folder_path):
        file_path = os.path.join(image_folder_path, file_name)
        if os.path.isfile(file_path):
            id = os.path.splitext(os.path.basename(file_name))[0]
            if id in result: continue
            with Image.open(file_path).convert("RGB") as img:
                image = image_transforms(img)
                result[id] = torch.Tensor(image)
            count += 1
            progress = count / file_counts * 100
            if progress - last_log >= 10:
                last_log = progress
                print(f"Image loading {progress:.2f} % ")
    if last_log != 100:
        print(f"Image loading 100 % ")
    return result
def process_data(image_dict: dict[str, torch.Tensor], processed_data_path: str):
    result: list[tuple[torch.Tensor, torch.Tensor]] = []
    with open(processed_data_path, 'rb') as file:
        processed_data = pickle.load(file)
    for image_id, caption in processed_data:
        caption = torch.tensor(caption)
        image = image_dict[image_id]
        result.append((image, caption))
    return result
def get_train_test_loader(batch_size: int, n_wokers: int = 2):
    image_path = "../Flickr8k/Flicker8k_Dataset"
    train_path = "../Data_bert/train_set_bert.pkl"
    test_path = "../Data_bert/test_set_bert.pkl"
    image_dict: dict[str, torch.Tensor] = load_image(image_path)
    train_data = process_data(image_dict, train_path)
    test_data = process_data(image_dict, test_path)
    trainloader = DataLoader(train_data, batch_size=batch_size, num_workers=n_wokers, shuffle=True)
    testloader = DataLoader(test_data, batch_size=batch_size, num_workers=n_wokers, shuffle=False)
    return trainloader, testloader
trainloader, testloader = get_train_test_loader(BATCH_SIZE, 2)

Image loading 10.01 % 
Image loading 20.02 % 
Image loading 30.03 % 
Image loading 40.04 % 
Image loading 50.06 % 
Image loading 60.07 % 
Image loading 70.08 % 
Image loading 80.09 % 
Image loading 90.10 % 
Image loading 100 % 


In [3]:
# for images, captions in trainloader:
#     for i in range(min(4, captions.shape[0])):
#         print(images[i][:,100,100])
#         print(detokenize(captions[i]))
#     break

In [4]:
class EncoderCNN(nn.Module):
    def __init__(self, output_size: int):
        super(EncoderCNN, self).__init__()
        self.inception_model = models.inception_v3(pretrained=True)
        #self.inception_model.fc = torch.nn.Identity()
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.2)
        self.fc = nn.Linear(1000, output_size)
        for name, param in self.inception_model.named_parameters():
            if "fc.weight" in name or "fc.bias" in name:
                param.requires_grad = True
            else:
                param.requires_grad = False
    def forward(self, images: torch.Tensor):
        features = self.inception_model(images) #[1, 2048]
        if isinstance(features, tuple):  # Nếu là tuple
            features = features[0] 
        features = self.relu(features)
        features = self.dropout(features)
        features = self.fc(features)
        return features
class DecoderRNN(nn.Module):
    def __init__(self, embed_size, vocab_size, hidden_size, input_size, num_layers):
        super(DecoderRNN, self).__init__()
        self.vocab_size = vocab_size
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size + input_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.use_hidden = False
    def forward(self, features: torch.Tensor, captions: torch.Tensor, hidden_state: tuple[torch.Tensor, torch.Tensor] = None): 
        # Constant : seq = 1
        # features : [bsz, img_sz]
        # captions : [bsz, seq]
        # hidden : [num_layers, bsz, hidden]
        embeddings = self.embed(captions) # [bsz, seq, embed]
        features = features.unsqueeze(1).expand(-1, embeddings.shape[1], -1) # [bsz, seq, embed]
        combined = torch.cat((features, embeddings), dim=2) # [bsz, seq, img_sz + embed]
        # hidden_state : [num_layers, seq, hid] * 2
        if self.use_hidden:
            output, hidden_state = self.lstm(combined, hidden_state)
        else:
            output, hidden_state = self.lstm(combined)
        # output : [bsz, seq, vocab_size]
        output = self.relu(output)
        output = self.dropout(output)
        output = self.linear(output)
        return output, hidden_state
class ImageToTextModel(nn.Module):
    def __init__(self, encoder: nn.Module, decoder: nn.Module):
        super(ImageToTextModel, self).__init__()
        self.encoder: EncoderCNN = encoder
        self.decoder: DecoderRNN = decoder
    def forward(self, images: torch.Tensor, captions: torch.Tensor):
        # Constant : SEQ_LENGTH = 46, seq = 1
        # images: [bsz, 3, raw_image_width, raw_image_height]
        # captions: [bsz, SEQ_LENGTH]
        bsz = images.shape[0]
        if self.decoder.use_hidden:
            hidden_state: tuple[torch.Tensor, torch.Tensor] = None
        features = self.encoder(images)
        # features: [bsz, img_sz]
        seq_predicted = []
        seq_predicted.append(torch.zeros((bsz, self.decoder.vocab_size), dtype=torch.float32).unsqueeze(1).to(device))
        # seq_predicted : [predict_length, seq, vocab]
        decoder_input = captions[:, 0].unsqueeze(1)
        # decoder_input : [bsz, seq]
        seq_length = captions.shape[1]
        for di in range(1, seq_length):
            if self.decoder.use_hidden:
                output_decoder, hidden_state = self.decoder(features, decoder_input, hidden_state)
            else:
                output_decoder, hidden_state = self.decoder(features, decoder_input)
            # ouput_decoder: [bsz, seq, vocab]
            # hidden_state: [num_layers, bsz, hidden_size] * 2
            decoder_input = captions[:, di].unsqueeze(1)
            # decoder_input : [bsz, seq]
            seq_predicted.append(output_decoder)
        return torch.cat(seq_predicted, dim=1)
    def predict(self, images: torch.Tensor, captions: torch.Tensor, predict_length: int):
        bsz = images.shape[0]
        hidden_state: tuple[torch.Tensor, torch.Tensor] = None
        features = self.encoder(images)
        seq_predicted = []
        seq_predicted.append(torch.zeros((bsz, self.decoder.vocab_size), dtype=torch.float32).unsqueeze(1).to(device))
        decoder_input = captions
        for _ in range(predict_length-1):
            output_decoder, hidden_state = self.decoder(features, decoder_input, hidden_state)
            seq_predicted.append(output_decoder)
            decoder_input = output_decoder.argmax(2)
        return torch.cat(seq_predicted, dim=1)


In [5]:
def train(model : ImageToTextModel, dataloader : DataLoader, lossf : callable, optimizer : torch.optim.Optimizer, mixed: bool, device: str):
    model.train()
    epoch_loss = 0
    count = 0
    for images, captions in dataloader:
        optimizer.zero_grad()
        images: torch.Tensor = images.to(device)
        captions: torch.Tensor = captions.to(device)
        if mixed:
            with torch.autocast(device_type=device, dtype=torch.bfloat16):
                outputs: torch.Tensor = model(images, captions)
                loss: torch.Tensor = lossf(outputs.view(-1, outputs.shape[2]), captions.view(-1))
        else:
            outputs: torch.Tensor = model(images, captions)
            loss: torch.Tensor = lossf(outputs.view(-1, outputs.shape[2]), captions.view(-1))
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        count += 1 
    return epoch_loss / count
def compute_metrics(predict: torch.Tensor, caption: torch.Tensor): # Quên chưa xóa token đầu ...
    # Accuracy
    is_correct = 1 if predict[1] == caption[1] else 0
    caption: list[str] = detokenize(caption)
    predict: list[str] = detokenize(predict)
    while (caption[-1] == "[PAD]"):
        caption.pop()
    while (predict[-1] == "[PAD]"):
        predict.pop()
    # Bleu
    reference = [caption]
    # print(caption, predict)
    bleu_1 = sentence_bleu(reference, predict, weights=(1, 0, 0, 0))
    bleu_2 = sentence_bleu(reference, predict, weights=(0.5, 0.5, 0, 0))
    bleu_3 = sentence_bleu(reference, predict, weights=(0.33, 0.33, 0.33, 0))
    bleu_4 = sentence_bleu(reference, predict)
    # Rogue
    # rogue = Rogue()
    # scores = rogue.get_scores(' '.join(predicts), ' '.join(captions))
    # Meteor
    # meteor_score_ = meteor_score([' '.join(captions)], ' '.join(predicts))
    # Cider, spice
    # cider_scorer = Cider()
    # spice_scorer = Spice()
    # cider_score, _ = cider_scorer.compute_score({0 : [' '.join(caption)]}, {0 : [' '.join(predict)]})
    # spice_score, _ = spice_scorer.compute_score({0 : [' '.join(captions)]}, {0 : [' '.join(predicts)]})
    return {
        "accuracy" : is_correct,
        "bleu_1" : bleu_1,
        "bleu_2" : bleu_2,
        "bleu_3" : bleu_3,
        "bleu_4" : bleu_4,
        # "rogue-1" : scores["rogue-1"]["f"],
        # "rogue-2" : scores["rogue-2"]["f"],
        # "rogue-l" : scores["rogue-l"]["f"],
        # "meteor" : meteor_score_,
        # "cider" : cider_score,
        # "spice" : spice_score
    }
def test(model : ImageToTextModel, dataloader : DataLoader, lossf : callable, mixed: bool, device: str):
    model.eval()
    count = 0
    total_count = 0
    epoch_metrics = {'loss' : 0}
    input = torch.tensor([101]).to(device)
    for images, captions in dataloader:
        images: torch.Tensor = images.to(device)
        captions: torch.Tensor = captions.to(device)
        bsz = images.shape[0]
        inputs = input.unsqueeze(1).expand((bsz, 1))
        if mixed:
            with torch.autocast(device_type=device, dtype=torch.bfloat16):
                outputs: torch.Tensor = model.predict(images, inputs, captions.shape[1])
                loss: torch.Tensor= lossf(outputs.view(-1, outputs.shape[2]), captions.view(-1))
        else:
            outputs: torch.Tensor  = model.predict(images, inputs, captions.shape[1])
            loss: torch.Tensor= lossf(outputs.view(-1, outputs.shape[2]), captions.view(-1))
        predicts = outputs.argmax(2)
        count += 1
        total_count += bsz
        epoch_metrics['loss'] += loss.item()
        for i in range(bsz):
            metrics = compute_metrics(predicts[i], captions[i])
            for key in metrics:
                if key not in epoch_metrics:
                    epoch_metrics[key] = metrics[key]
                else:
                    epoch_metrics[key] += metrics[key]
    for key in epoch_metrics:
        if key in ['loss']:
            epoch_metrics[key] /= count
        else:
            epoch_metrics[key] /= total_count
    return epoch_metrics

In [6]:
import util
image_size = 256
embed_size = 256
hidden_size = 256
encoder = EncoderCNN(
    output_size=image_size
)
decoder = DecoderRNN(
    embed_size=embed_size,
    vocab_size=tokenizer.vocab_size,
    hidden_size=hidden_size,
    input_size=image_size,
    num_layers=1
)
decoder.use_hidden = True
image_to_text_model = ImageToTextModel(
    encoder=encoder,
    decoder=decoder
)
loss_func= nn.CrossEntropyLoss(ignore_index=0)
model_path, train_log = util.train_eval(
    trainloader=trainloader, # train data loader
    testloader=testloader, # test data loader
    model=image_to_text_model, # model
    train_func=train, # hàm train
    test_func=test, # hàm test
    lossf=loss_func, # hàm loss
    end_epoch=50, # số epoch kết thúc train (tăng lên khi train từ checkpoint nhá, ko phải số epoch nó chạy code mà số epoch nó sẽ dừng)
    lr=2.5e-3, # lr sau khi warmup
    gamma=0.95, # lr * gamma ^ epoch sau warmup
    log_step=1, # in ra console mỗi k epochs
    warmup_nepochs=10, # số epoch chạy warmup scheduler
    warmup_lr=1e-3, # lr ban đầu của warmup
    warmup_gamma=1.1, # lr * gamma ^ epoch khi chạy warmup
    save_weight=True, # lưu trọng số mô hình
    save_full=True, # lưu optimizer, scheduler để train
    save_each=1, # luu mỗi k epoch
    mixed_train = True, # sử dụng bf16 cho train
    mixed_eval = False, # sử dùng bf16 cho eval
    # load_checkpoint = True, # load checkpoint cũ ko
    # load_optimizer = True, # load optimizer cũ ko (sau khi load checkpoint)
    # checkpoint_path = "checkpoint/i256_e256_h256_model/2024-12-12_20-37-10", # folder path của checkpoint
    save_path="checkpoint/i256_e256_h_256_model", # folder path lưu mô hình
    optimizer_type=torch.optim.Adam, # loại optimizer sử dụng
    device=device, # cpu hay cuda
    metadata_extra={ # tham số lưu thêm vào metada khi lưu mô hình
        "batch_size" : BATCH_SIZE,
        "dataset_name" : "Flickr8k",
        "use_hidden" : decoder.use_hidden
    },
    log_metric=True # in các metrics ra cốnle
)



Start train


The hypothesis contains 0 counts of 2-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


Train loss : 6.0180 | Test loss : 5.7932 | Train time : 241.04 s | Lr : 0.00100000
{'loss': 5.793184793448146, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


  2%|▏         | 1/50 [04:50<3:57:04, 290.29s/it]

Train loss : 5.8286 | Test loss : 5.7891 | Train time : 199.50 s | Lr : 0.00110000
{'loss': 5.789072833483732, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


  4%|▍         | 2/50 [08:51<3:29:17, 261.62s/it]

Train loss : 5.8183 | Test loss : 5.7869 | Train time : 193.56 s | Lr : 0.00121000
{'loss': 5.786935649340665, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


  6%|▌         | 3/50 [12:48<3:16:00, 250.22s/it]

Train loss : 5.8170 | Test loss : 5.7934 | Train time : 204.33 s | Lr : 0.00133100
{'loss': 5.793391927888122, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


  8%|▊         | 4/50 [16:55<3:10:42, 248.76s/it]

Train loss : 5.8202 | Test loss : 5.7950 | Train time : 192.82 s | Lr : 0.00146410
{'loss': 5.794977375223667, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 10%|█         | 5/50 [20:49<3:02:39, 243.54s/it]

Train loss : 5.8224 | Test loss : 5.8009 | Train time : 196.89 s | Lr : 0.00161051
{'loss': 5.800944998294493, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 12%|█▏        | 6/50 [24:47<2:57:13, 241.67s/it]

Train loss : 5.8250 | Test loss : 5.8041 | Train time : 190.59 s | Lr : 0.00177156
{'loss': 5.804094157641447, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 14%|█▍        | 7/50 [28:39<2:51:00, 238.62s/it]

Train loss : 5.8279 | Test loss : 5.8050 | Train time : 195.38 s | Lr : 0.00194872
{'loss': 5.804992567134809, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 16%|█▌        | 8/50 [32:36<2:46:37, 238.03s/it]

Train loss : 5.8638 | Test loss : 5.8103 | Train time : 195.83 s | Lr : 0.00214359
{'loss': 5.810324445555482, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 18%|█▊        | 9/50 [36:33<2:42:28, 237.77s/it]

Train loss : 5.8624 | Test loss : 5.8164 | Train time : 196.72 s | Lr : 0.00235795
{'loss': 5.816389011431344, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 20%|██        | 10/50 [40:31<2:38:37, 237.93s/it]

Train loss : 5.8566 | Test loss : 5.8165 | Train time : 190.76 s | Lr : 0.00250000
{'loss': 5.816462607323369, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 22%|██▏       | 11/50 [44:23<2:33:26, 236.08s/it]

Train loss : 5.8495 | Test loss : 5.8085 | Train time : 194.22 s | Lr : 0.00237500
{'loss': 5.808523534219476, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 24%|██▍       | 12/50 [48:19<2:29:28, 236.02s/it]

Train loss : 5.8371 | Test loss : 5.8076 | Train time : 196.21 s | Lr : 0.00225625
{'loss': 5.807552198820476, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 26%|██▌       | 13/50 [52:16<2:25:44, 236.33s/it]

Train loss : 5.8288 | Test loss : 5.8005 | Train time : 194.76 s | Lr : 0.00214344
{'loss': 5.800544871559626, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 28%|██▊       | 14/50 [56:12<2:21:47, 236.31s/it]

Train loss : 5.8209 | Test loss : 5.7967 | Train time : 192.83 s | Lr : 0.00203627
{'loss': 5.796684989446326, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 30%|███       | 15/50 [1:00:06<2:17:23, 235.52s/it]

Train loss : 5.8152 | Test loss : 5.7942 | Train time : 198.06 s | Lr : 0.00193445
{'loss': 5.794190907780128, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 32%|███▏      | 16/50 [1:04:05<2:14:05, 236.62s/it]

Train loss : 5.8062 | Test loss : 5.7928 | Train time : 195.86 s | Lr : 0.00183773
{'loss': 5.7927530626707435, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 34%|███▍      | 17/50 [1:08:02<2:10:13, 236.76s/it]

Train loss : 5.7997 | Test loss : 5.7893 | Train time : 194.11 s | Lr : 0.00174584
{'loss': 5.789262838001493, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 36%|███▌      | 18/50 [1:11:58<2:06:04, 236.39s/it]

Train loss : 5.7949 | Test loss : 5.7863 | Train time : 193.45 s | Lr : 0.00165855
{'loss': 5.786277517487731, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 38%|███▊      | 19/50 [1:15:53<2:01:52, 235.88s/it]

Train loss : 5.7873 | Test loss : 5.7834 | Train time : 194.18 s | Lr : 0.00157562
{'loss': 5.783353805541992, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 40%|████      | 20/50 [1:19:48<1:57:53, 235.79s/it]

Train loss : 5.7831 | Test loss : 5.7781 | Train time : 189.78 s | Lr : 0.00149684
{'loss': 5.778119153614286, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 42%|████▏     | 21/50 [1:23:39<1:53:16, 234.37s/it]

Train loss : 5.7760 | Test loss : 5.7791 | Train time : 192.52 s | Lr : 0.00142200
{'loss': 5.779054134706907, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 44%|████▍     | 22/50 [1:27:33<1:49:17, 234.19s/it]

Train loss : 5.7746 | Test loss : 5.7758 | Train time : 189.47 s | Lr : 0.00135090
{'loss': 5.775839920285382, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 46%|████▌     | 23/50 [1:31:23<1:44:52, 233.05s/it]

Train loss : 5.7677 | Test loss : 5.7754 | Train time : 197.04 s | Lr : 0.00128336
{'loss': 5.775412982023215, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 48%|████▊     | 24/50 [1:35:21<1:41:35, 234.45s/it]

Train loss : 5.7651 | Test loss : 5.7715 | Train time : 189.57 s | Lr : 0.00121919
{'loss': 5.771456175212618, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 50%|█████     | 25/50 [1:39:12<1:37:13, 233.34s/it]

Train loss : 5.7605 | Test loss : 5.7696 | Train time : 192.35 s | Lr : 0.00115823
{'loss': 5.769601695145233, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 52%|█████▏    | 26/50 [1:43:05<1:33:19, 233.32s/it]

Train loss : 5.7558 | Test loss : 5.7682 | Train time : 189.70 s | Lr : 0.00110032
{'loss': 5.768171346640285, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 54%|█████▍    | 27/50 [1:46:57<1:29:14, 232.80s/it]

Train loss : 5.7539 | Test loss : 5.7669 | Train time : 192.64 s | Lr : 0.00104530
{'loss': 5.766906768460817, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 56%|█████▌    | 28/50 [1:50:51<1:25:27, 233.07s/it]

Train loss : 5.7517 | Test loss : 5.7654 | Train time : 192.32 s | Lr : 0.00099304
{'loss': 5.765393709834618, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 58%|█████▊    | 29/50 [1:54:44<1:21:38, 233.24s/it]

Train loss : 5.7473 | Test loss : 5.7655 | Train time : 190.69 s | Lr : 0.00094338
{'loss': 5.765492602239681, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 60%|██████    | 30/50 [1:58:36<1:17:37, 232.90s/it]

Train loss : 5.7439 | Test loss : 5.7614 | Train time : 189.14 s | Lr : 0.00089621
{'loss': 5.761425971984863, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 62%|██████▏   | 31/50 [2:02:27<1:13:33, 232.28s/it]

Train loss : 5.7408 | Test loss : 5.7631 | Train time : 198.36 s | Lr : 0.00085140
{'loss': 5.763111265399788, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 64%|██████▍   | 32/50 [2:06:27<1:10:22, 234.61s/it]

Train loss : 5.7385 | Test loss : 5.7599 | Train time : 193.09 s | Lr : 0.00080883
{'loss': 5.759867438787146, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 66%|██████▌   | 33/50 [2:10:22<1:06:27, 234.56s/it]

Train loss : 5.7366 | Test loss : 5.7585 | Train time : 190.47 s | Lr : 0.00076839
{'loss': 5.7585116217408, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 68%|██████▊   | 34/50 [2:14:14<1:02:20, 233.80s/it]

Train loss : 5.7340 | Test loss : 5.7584 | Train time : 191.16 s | Lr : 0.00072997
{'loss': 5.758410556406915, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 70%|███████   | 35/50 [2:18:06<58:20, 233.40s/it]  

Train loss : 5.7306 | Test loss : 5.7559 | Train time : 194.85 s | Lr : 0.00069347
{'loss': 5.755914615679391, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 72%|███████▏  | 36/50 [2:22:02<54:38, 234.18s/it]

Train loss : 5.7295 | Test loss : 5.7572 | Train time : 191.62 s | Lr : 0.00065880
{'loss': 5.757190945782239, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 74%|███████▍  | 37/50 [2:25:55<50:40, 233.89s/it]

Train loss : 5.7264 | Test loss : 5.7557 | Train time : 193.55 s | Lr : 0.00062586
{'loss': 5.755673106712631, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 76%|███████▌  | 38/50 [2:29:50<46:48, 234.07s/it]

Train loss : 5.7270 | Test loss : 5.7544 | Train time : 196.69 s | Lr : 0.00059457
{'loss': 5.754362613339968, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 78%|███████▊  | 39/50 [2:33:48<43:07, 235.18s/it]

Train loss : 5.7232 | Test loss : 5.7538 | Train time : 195.62 s | Lr : 0.00056484
{'loss': 5.7537515555756, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 80%|████████  | 40/50 [2:37:44<39:16, 235.66s/it]

Train loss : 5.7228 | Test loss : 5.7527 | Train time : 193.19 s | Lr : 0.00053660
{'loss': 5.752682281445853, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 82%|████████▏ | 41/50 [2:41:39<35:17, 235.26s/it]

Train loss : 5.7201 | Test loss : 5.7529 | Train time : 190.19 s | Lr : 0.00050977
{'loss': 5.752877887291245, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 84%|████████▍ | 42/50 [2:45:30<31:12, 234.12s/it]

Train loss : 5.7183 | Test loss : 5.7517 | Train time : 192.41 s | Lr : 0.00048428
{'loss': 5.751736139949364, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 86%|████████▌ | 43/50 [2:49:24<27:19, 234.19s/it]

Train loss : 5.7175 | Test loss : 5.7516 | Train time : 193.96 s | Lr : 0.00046006
{'loss': 5.7516140877446045, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 88%|████████▊ | 44/50 [2:53:20<23:27, 234.54s/it]

Train loss : 5.7147 | Test loss : 5.7518 | Train time : 193.93 s | Lr : 0.00043706
{'loss': 5.7517850006682965, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 90%|█████████ | 45/50 [2:57:15<19:32, 234.59s/it]

Train loss : 5.7150 | Test loss : 5.7497 | Train time : 193.79 s | Lr : 0.00041521
{'loss': 5.749671730814101, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 92%|█████████▏| 46/50 [3:01:10<15:39, 234.85s/it]

Train loss : 5.7136 | Test loss : 5.7494 | Train time : 193.86 s | Lr : 0.00039445
{'loss': 5.749363301675531, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 94%|█████████▍| 47/50 [3:05:06<11:45, 235.10s/it]

Train loss : 5.7124 | Test loss : 5.7491 | Train time : 191.91 s | Lr : 0.00037473
{'loss': 5.749057932745052, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 96%|█████████▌| 48/50 [3:08:58<07:48, 234.39s/it]

Train loss : 5.7117 | Test loss : 5.7492 | Train time : 190.05 s | Lr : 0.00035599
{'loss': 5.749236342273181, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


100%|██████████| 50/50 [3:16:46<00:00, 236.13s/it]

Train loss : 5.7096 | Test loss : 5.7485 | Train time : 194.53 s | Lr : 0.00033819
{'loss': 5.748467173757432, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}





Complete


In [7]:
import util
image_size = 256
embed_size = 64
hidden_size = 256
encoder = EncoderCNN(
    output_size=image_size
)
decoder = DecoderRNN(
    embed_size=embed_size,
    vocab_size=tokenizer.vocab_size,
    hidden_size=hidden_size,
    input_size=image_size,
    num_layers=1
)
decoder.use_hidden = True
image_to_text_model = ImageToTextModel(
    encoder=encoder,
    decoder=decoder
)
loss_func= nn.CrossEntropyLoss(ignore_index=0)
model_path, train_log = util.train_eval(
    trainloader=trainloader, # train data loader
    testloader=testloader, # test data loader
    model=image_to_text_model, # model
    train_func=train, # hàm train
    test_func=test, # hàm test
    lossf=loss_func, # hàm loss
    end_epoch=50, # số epoch kết thúc train (tăng lên khi train từ checkpoint nhá, ko phải số epoch nó chạy code mà số epoch nó sẽ dừng)
    lr=2.5e-3, # lr sau khi warmup
    gamma=0.95, # lr * gamma ^ epoch sau warmup
    log_step=1, # in ra console mỗi k epochs
    warmup_nepochs=10, # số epoch chạy warmup scheduler
    warmup_lr=1e-3, # lr ban đầu của warmup
    warmup_gamma=1.1, # lr * gamma ^ epoch khi chạy warmup
    save_weight=True, # lưu trọng số mô hình
    save_full=True, # lưu optimizer, scheduler để train
    save_each=1, # luu mỗi k epoch
    mixed_train = True, # sử dụng bf16 cho train
    mixed_eval = False, # sử dùng bf16 cho eval
    # load_checkpoint = True, # load checkpoint cũ ko
    # load_optimizer = True, # load optimizer cũ ko (sau khi load checkpoint)
    # checkpoint_path = "checkpoint/i256_e256_h256_model/2024-12-12_20-37-10", # folder path của checkpoint
    save_path="checkpoint/i256_e64_h_256_model", # folder path lưu mô hình
    optimizer_type=torch.optim.Adam, # loại optimizer sử dụng
    device=device, # cpu hay cuda
    metadata_extra={ # tham số lưu thêm vào metada khi lưu mô hình
        "batch_size" : BATCH_SIZE,
        "dataset_name" : "Flickr8k",
        "use_hidden" : decoder.use_hidden
    },
    log_metric=True # in các metrics ra cốnle
)

Start train


  0%|          | 0/50 [00:00<?, ?it/s]

Train loss : 6.0202 | Test loss : 5.7914 | Train time : 179.92 s | Lr : 0.00100000
{'loss': 5.791400951675222, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


  2%|▏         | 1/50 [03:40<3:00:17, 220.77s/it]

Train loss : 5.8250 | Test loss : 5.7859 | Train time : 183.47 s | Lr : 0.00110000
{'loss': 5.785941063603269, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


  4%|▍         | 2/50 [07:25<2:58:26, 223.05s/it]

Train loss : 5.8168 | Test loss : 5.7899 | Train time : 180.20 s | Lr : 0.00121000
{'loss': 5.789854858495012, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


  6%|▌         | 3/50 [11:06<2:54:04, 222.23s/it]

Train loss : 5.8131 | Test loss : 5.7939 | Train time : 181.04 s | Lr : 0.00133100
{'loss': 5.793930047675024, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


  8%|▊         | 4/50 [14:48<2:50:12, 222.01s/it]

Train loss : 5.8134 | Test loss : 5.7929 | Train time : 181.24 s | Lr : 0.00146410
{'loss': 5.792937876302985, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 10%|█         | 5/50 [18:30<2:46:29, 221.99s/it]

Train loss : 5.8141 | Test loss : 5.7958 | Train time : 182.69 s | Lr : 0.00161051
{'loss': 5.7957523985754085, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 12%|█▏        | 6/50 [22:14<2:43:13, 222.58s/it]

Train loss : 5.8177 | Test loss : 5.7992 | Train time : 182.63 s | Lr : 0.00177156
{'loss': 5.799248073674455, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 14%|█▍        | 7/50 [25:57<2:39:46, 222.93s/it]

Train loss : 5.8359 | Test loss : 5.8074 | Train time : 182.71 s | Lr : 0.00194872
{'loss': 5.807439435886431, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 16%|█▌        | 8/50 [29:41<2:36:18, 223.29s/it]

Train loss : 5.8532 | Test loss : 5.8079 | Train time : 178.57 s | Lr : 0.00214359
{'loss': 5.8078729170787184, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 18%|█▊        | 9/50 [33:21<2:31:44, 222.06s/it]

Train loss : 5.8512 | Test loss : 5.8141 | Train time : 179.26 s | Lr : 0.00235795
{'loss': 5.814094241661362, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 20%|██        | 10/50 [37:01<2:27:45, 221.63s/it]

Train loss : 5.8498 | Test loss : 5.8141 | Train time : 179.61 s | Lr : 0.00250000
{'loss': 5.814090179491647, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 22%|██▏       | 11/50 [40:42<2:23:53, 221.38s/it]

Train loss : 5.8419 | Test loss : 5.8093 | Train time : 182.68 s | Lr : 0.00237500
{'loss': 5.809280419651466, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 24%|██▍       | 12/50 [44:26<2:20:42, 222.16s/it]

Train loss : 5.8332 | Test loss : 5.8041 | Train time : 173.71 s | Lr : 0.00225625
{'loss': 5.804107110711593, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 26%|██▌       | 13/50 [48:01<2:15:39, 220.00s/it]

Train loss : 5.8247 | Test loss : 5.7977 | Train time : 181.94 s | Lr : 0.00214344
{'loss': 5.797715398329723, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 28%|██▊       | 14/50 [51:44<2:12:30, 220.84s/it]

Train loss : 5.8163 | Test loss : 5.7986 | Train time : 176.79 s | Lr : 0.00203627
{'loss': 5.798568870447859, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 30%|███       | 15/50 [55:22<2:08:18, 219.97s/it]

Train loss : 5.8079 | Test loss : 5.7923 | Train time : 182.48 s | Lr : 0.00193445
{'loss': 5.792282955555976, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 32%|███▏      | 16/50 [59:05<2:05:16, 221.08s/it]

Train loss : 5.7996 | Test loss : 5.7897 | Train time : 182.85 s | Lr : 0.00183773
{'loss': 5.789684929425204, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 34%|███▍      | 17/50 [1:02:49<2:02:04, 221.95s/it]

Train loss : 5.7953 | Test loss : 5.7863 | Train time : 176.28 s | Lr : 0.00174584
{'loss': 5.786252148543732, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 36%|███▌      | 18/50 [1:06:27<1:57:38, 220.57s/it]

Train loss : 5.7895 | Test loss : 5.7788 | Train time : 177.38 s | Lr : 0.00165855
{'loss': 5.778790298896499, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 38%|███▊      | 19/50 [1:10:06<1:53:43, 220.10s/it]

Train loss : 5.7837 | Test loss : 5.7777 | Train time : 177.40 s | Lr : 0.00157562
{'loss': 5.777704389789436, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 40%|████      | 20/50 [1:13:44<1:49:46, 219.56s/it]

Train loss : 5.7787 | Test loss : 5.7770 | Train time : 181.05 s | Lr : 0.00149684
{'loss': 5.777035055281241, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 42%|████▏     | 21/50 [1:17:27<1:46:34, 220.50s/it]

Train loss : 5.7726 | Test loss : 5.7775 | Train time : 181.46 s | Lr : 0.00142200
{'loss': 5.777491122861452, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 44%|████▍     | 22/50 [1:21:09<1:43:10, 221.10s/it]

Train loss : 5.7668 | Test loss : 5.7758 | Train time : 177.17 s | Lr : 0.00135090
{'loss': 5.775846674472471, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 46%|████▌     | 23/50 [1:24:48<1:39:09, 220.34s/it]

Train loss : 5.7632 | Test loss : 5.7716 | Train time : 176.40 s | Lr : 0.00128336
{'loss': 5.77159536941142, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 48%|████▊     | 24/50 [1:28:25<1:35:06, 219.49s/it]

Train loss : 5.7599 | Test loss : 5.7727 | Train time : 185.54 s | Lr : 0.00121919
{'loss': 5.772661613512643, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 50%|█████     | 25/50 [1:32:12<1:32:18, 221.54s/it]

Train loss : 5.7553 | Test loss : 5.7641 | Train time : 175.07 s | Lr : 0.00115823
{'loss': 5.764060177380526, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 52%|█████▏    | 26/50 [1:35:48<1:27:56, 219.86s/it]

Train loss : 5.7527 | Test loss : 5.7675 | Train time : 182.06 s | Lr : 0.00110032
{'loss': 5.767464016057268, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 54%|█████▍    | 27/50 [1:39:31<1:24:41, 220.94s/it]

Train loss : 5.7477 | Test loss : 5.7649 | Train time : 177.24 s | Lr : 0.00104530
{'loss': 5.764890531950359, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 56%|█████▌    | 28/50 [1:43:09<1:20:44, 220.19s/it]

Train loss : 5.7453 | Test loss : 5.7638 | Train time : 180.47 s | Lr : 0.00099304
{'loss': 5.7637914223007005, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 58%|█████▊    | 29/50 [1:46:51<1:17:15, 220.72s/it]

Train loss : 5.7426 | Test loss : 5.7624 | Train time : 181.73 s | Lr : 0.00094338
{'loss': 5.762442763847641, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 60%|██████    | 30/50 [1:50:34<1:13:47, 221.39s/it]

Train loss : 5.7411 | Test loss : 5.7611 | Train time : 179.44 s | Lr : 0.00089621
{'loss': 5.7610932301871385, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 62%|██████▏   | 31/50 [1:54:15<1:10:01, 221.13s/it]

Train loss : 5.7362 | Test loss : 5.7602 | Train time : 181.66 s | Lr : 0.00085140
{'loss': 5.760174491737462, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 64%|██████▍   | 32/50 [1:57:58<1:06:30, 221.72s/it]

Train loss : 5.7338 | Test loss : 5.7598 | Train time : 183.57 s | Lr : 0.00080883
{'loss': 5.759840965270996, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 66%|██████▌   | 33/50 [2:01:43<1:03:05, 222.67s/it]

Train loss : 5.7314 | Test loss : 5.7582 | Train time : 179.75 s | Lr : 0.00076839
{'loss': 5.7581550199774245, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 68%|██████▊   | 34/50 [2:05:24<59:13, 222.11s/it]  

Train loss : 5.7301 | Test loss : 5.7552 | Train time : 174.99 s | Lr : 0.00072997
{'loss': 5.755234332024297, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 70%|███████   | 35/50 [2:08:59<55:02, 220.17s/it]

Train loss : 5.7267 | Test loss : 5.7546 | Train time : 171.95 s | Lr : 0.00069347
{'loss': 5.754575747477857, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 72%|███████▏  | 36/50 [2:12:33<50:53, 218.12s/it]

Train loss : 5.7239 | Test loss : 5.7539 | Train time : 180.10 s | Lr : 0.00065880
{'loss': 5.753933574579939, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 74%|███████▍  | 37/50 [2:16:14<47:28, 219.11s/it]

Train loss : 5.7225 | Test loss : 5.7526 | Train time : 181.28 s | Lr : 0.00062586
{'loss': 5.752578964716272, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 76%|███████▌  | 38/50 [2:19:57<44:01, 220.13s/it]

Train loss : 5.7219 | Test loss : 5.7522 | Train time : 176.95 s | Lr : 0.00059457
{'loss': 5.752185610276234, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 78%|███████▊  | 39/50 [2:23:35<40:15, 219.61s/it]

Train loss : 5.7194 | Test loss : 5.7497 | Train time : 176.55 s | Lr : 0.00056484
{'loss': 5.749669914004169, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 80%|████████  | 40/50 [2:27:13<36:30, 219.05s/it]

Train loss : 5.7171 | Test loss : 5.7512 | Train time : 181.01 s | Lr : 0.00053660
{'loss': 5.751237953765483, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 82%|████████▏ | 41/50 [2:30:55<32:59, 219.90s/it]

Train loss : 5.7160 | Test loss : 5.7511 | Train time : 179.57 s | Lr : 0.00050977
{'loss': 5.751052325284934, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 84%|████████▍ | 42/50 [2:34:36<29:21, 220.19s/it]

Train loss : 5.7147 | Test loss : 5.7492 | Train time : 182.27 s | Lr : 0.00048428
{'loss': 5.749211655387396, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 86%|████████▌ | 43/50 [2:38:19<25:48, 221.26s/it]

Train loss : 5.7131 | Test loss : 5.7484 | Train time : 177.67 s | Lr : 0.00046006
{'loss': 5.748361110687256, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 88%|████████▊ | 44/50 [2:41:58<22:03, 220.53s/it]

Train loss : 5.7113 | Test loss : 5.7472 | Train time : 184.39 s | Lr : 0.00043706
{'loss': 5.74718320822414, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 90%|█████████ | 45/50 [2:45:44<18:30, 222.03s/it]

Train loss : 5.7092 | Test loss : 5.7466 | Train time : 181.91 s | Lr : 0.00041521
{'loss': 5.74658729456648, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 92%|█████████▏| 46/50 [2:49:27<14:49, 222.32s/it]

Train loss : 5.7085 | Test loss : 5.7463 | Train time : 179.22 s | Lr : 0.00039445
{'loss': 5.746260105809079, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 94%|█████████▍| 47/50 [2:53:07<11:05, 221.74s/it]

Train loss : 5.7060 | Test loss : 5.7457 | Train time : 179.12 s | Lr : 0.00037473
{'loss': 5.745688679852063, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 96%|█████████▌| 48/50 [2:56:48<07:22, 221.42s/it]

Train loss : 5.7062 | Test loss : 5.7456 | Train time : 184.05 s | Lr : 0.00035599
{'loss': 5.745610496665858, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


100%|██████████| 50/50 [3:04:16<00:00, 221.13s/it]

Train loss : 5.7049 | Test loss : 5.7458 | Train time : 182.40 s | Lr : 0.00033819
{'loss': 5.745810194860531, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}





Complete


In [8]:
import util
image_size = 64
embed_size = 256
hidden_size = 256
encoder = EncoderCNN(
    output_size=image_size
)
decoder = DecoderRNN(
    embed_size=embed_size,
    vocab_size=tokenizer.vocab_size,
    hidden_size=hidden_size,
    input_size=image_size,
    num_layers=1
)
decoder.use_hidden = True
image_to_text_model = ImageToTextModel(
    encoder=encoder,
    decoder=decoder
)
loss_func= nn.CrossEntropyLoss(ignore_index=0)
model_path, train_log = util.train_eval(
    trainloader=trainloader, # train data loader
    testloader=testloader, # test data loader
    model=image_to_text_model, # model
    train_func=train, # hàm train
    test_func=test, # hàm test
    lossf=loss_func, # hàm loss
    end_epoch=50, # số epoch kết thúc train (tăng lên khi train từ checkpoint nhá, ko phải số epoch nó chạy code mà số epoch nó sẽ dừng)
    lr=2.5e-3, # lr sau khi warmup
    gamma=0.95, # lr * gamma ^ epoch sau warmup
    log_step=1, # in ra console mỗi k epochs
    warmup_nepochs=10, # số epoch chạy warmup scheduler
    warmup_lr=1e-3, # lr ban đầu của warmup
    warmup_gamma=1.1, # lr * gamma ^ epoch khi chạy warmup
    save_weight=True, # lưu trọng số mô hình
    save_full=True, # lưu optimizer, scheduler để train
    save_each=1, # luu mỗi k epoch
    mixed_train = True, # sử dụng bf16 cho train
    mixed_eval = False, # sử dùng bf16 cho eval
    # load_checkpoint = True, # load checkpoint cũ ko
    # load_optimizer = True, # load optimizer cũ ko (sau khi load checkpoint)
    # checkpoint_path = "checkpoint/i256_e256_h256_model/2024-12-12_20-37-10", # folder path của checkpoint
    save_path="checkpoint/i256_e64_h_256_model", # folder path lưu mô hình
    optimizer_type=torch.optim.Adam, # loại optimizer sử dụng
    device=device, # cpu hay cuda
    metadata_extra={ # tham số lưu thêm vào metada khi lưu mô hình
        "batch_size" : BATCH_SIZE,
        "dataset_name" : "Flickr8k",
        "use_hidden" : decoder.use_hidden
    },
    log_metric=True # in các metrics ra cốnle
)

Start train


  0%|          | 0/50 [00:00<?, ?it/s]

Train loss : 5.3344 | Test loss : 6.4149 | Train time : 197.46 s | Lr : 0.00100000
{'loss': 6.414938926696777, 'accuracy': 0.615, 'bleu_1': 0.09072173913043691, 'bleu_2': 0.04467780976477441, 'bleu_3': 0.006383117533577904, 'bleu_4': 0.0023118513968456464}


  2%|▏         | 1/50 [03:59<3:15:40, 239.61s/it]

Train loss : 4.1432 | Test loss : 7.0560 | Train time : 190.65 s | Lr : 0.00110000
{'loss': 7.056033279322371, 'accuracy': 0.615, 'bleu_1': 0.10394347826087183, 'bleu_2': 0.057783027093006444, 'bleu_3': 0.007852922847774526, 'bleu_4': 0.002917439640644708}


  4%|▍         | 2/50 [07:52<3:08:19, 235.41s/it]

Train loss : 3.8426 | Test loss : 7.3638 | Train time : 190.13 s | Lr : 0.00121000
{'loss': 7.363771402383152, 'accuracy': 0.615, 'bleu_1': 0.10083478260869809, 'bleu_2': 0.054917115480425556, 'bleu_3': 0.00988092663298743, 'bleu_4': 0.0037101771280499156}


  6%|▌         | 3/50 [11:44<3:03:19, 234.03s/it]

Train loss : 3.6611 | Test loss : 7.5418 | Train time : 192.63 s | Lr : 0.00133100
{'loss': 7.541764524918568, 'accuracy': 0.615, 'bleu_1': 0.10319130434782899, 'bleu_2': 0.057605445053012724, 'bleu_3': 0.012119100645152979, 'bleu_4': 0.004740382203341209}


  8%|▊         | 4/50 [15:39<2:59:39, 234.34s/it]

Train loss : 3.5326 | Test loss : 7.7633 | Train time : 195.06 s | Lr : 0.00146410
{'loss': 7.763250616532337, 'accuracy': 0.615, 'bleu_1': 0.10887826086956759, 'bleu_2': 0.06204949489422815, 'bleu_3': 0.012827980695604438, 'bleu_4': 0.004975062802016007}


 10%|█         | 5/50 [19:36<2:56:34, 235.43s/it]

Train loss : 3.4318 | Test loss : 8.0229 | Train time : 189.68 s | Lr : 0.00161051
{'loss': 8.022933350333684, 'accuracy': 0.615, 'bleu_1': 0.10641739130435063, 'bleu_2': 0.059724992572867903, 'bleu_3': 0.013085562616249106, 'bleu_4': 0.005209117025344725}


 12%|█▏        | 6/50 [23:28<2:51:46, 234.23s/it]

Train loss : 3.3511 | Test loss : 8.2208 | Train time : 196.30 s | Lr : 0.00177156
{'loss': 8.220811288568038, 'accuracy': 0.615, 'bleu_1': 0.10706956521739384, 'bleu_2': 0.060306561176969475, 'bleu_3': 0.015043979168702216, 'bleu_4': 0.006195865108854003}


 14%|█▍        | 7/50 [27:27<2:49:01, 235.84s/it]

Train loss : 3.2846 | Test loss : 8.3092 | Train time : 195.29 s | Lr : 0.00194872
{'loss': 8.309247028978564, 'accuracy': 0.615, 'bleu_1': 0.11185217391304593, 'bleu_2': 0.06418769141004406, 'bleu_3': 0.014671560343281518, 'bleu_4': 0.005640309209773942}


 16%|█▌        | 8/50 [31:25<2:45:25, 236.32s/it]

Train loss : 3.2295 | Test loss : 8.5661 | Train time : 199.57 s | Lr : 0.00214359
{'loss': 8.566108788116068, 'accuracy': 0.615, 'bleu_1': 0.1085608695652199, 'bleu_2': 0.06237643898155261, 'bleu_3': 0.014143162439143384, 'bleu_4': 0.005851482352159896}


 18%|█▊        | 9/50 [35:26<2:42:37, 238.00s/it]

Train loss : 3.1854 | Test loss : 8.6163 | Train time : 190.61 s | Lr : 0.00235795
{'loss': 8.616333738158021, 'accuracy': 0.615, 'bleu_1': 0.10752608695652441, 'bleu_2': 0.061031282183684074, 'bleu_3': 0.014828241881643686, 'bleu_4': 0.005832323121224771}


 20%|██        | 10/50 [39:19<2:37:36, 236.42s/it]

Train loss : 3.1476 | Test loss : 8.8970 | Train time : 192.96 s | Lr : 0.00250000
{'loss': 8.896986611281768, 'accuracy': 0.615, 'bleu_1': 0.11135217391304611, 'bleu_2': 0.06416212433573568, 'bleu_3': 0.015571474223750789, 'bleu_4': 0.0069942188554717516}


 22%|██▏       | 11/50 [43:14<2:33:27, 236.09s/it]

Train loss : 3.0938 | Test loss : 8.9782 | Train time : 193.11 s | Lr : 0.00237500
{'loss': 8.978172326389748, 'accuracy': 0.615, 'bleu_1': 0.10871304347826329, 'bleu_2': 0.0623096593180097, 'bleu_3': 0.013530700768168364, 'bleu_4': 0.005183517111653175}


 24%|██▍       | 12/50 [47:10<2:29:21, 235.83s/it]

Train loss : 3.0465 | Test loss : 9.0485 | Train time : 192.58 s | Lr : 0.00225625
{'loss': 9.048511227474936, 'accuracy': 0.615, 'bleu_1': 0.11276521739130721, 'bleu_2': 0.06511750199165638, 'bleu_3': 0.016428160727629745, 'bleu_4': 0.0067266287963723475}


 26%|██▌       | 13/50 [51:05<2:25:14, 235.53s/it]

Train loss : 3.0047 | Test loss : 9.3153 | Train time : 193.77 s | Lr : 0.00214344
{'loss': 9.315316405477404, 'accuracy': 0.615, 'bleu_1': 0.10895652173913291, 'bleu_2': 0.06284820149382414, 'bleu_3': 0.015532942436387356, 'bleu_4': 0.006167771360976898}


 28%|██▊       | 14/50 [55:00<2:21:20, 235.58s/it]

Train loss : 2.9648 | Test loss : 9.4514 | Train time : 196.64 s | Lr : 0.00203627
{'loss': 9.451397594017319, 'accuracy': 0.615, 'bleu_1': 0.11036956521739377, 'bleu_2': 0.06301139272156775, 'bleu_3': 0.016249675234376478, 'bleu_4': 0.006640786771993962}


 30%|███       | 15/50 [58:59<2:18:01, 236.61s/it]

Train loss : 2.9302 | Test loss : 9.5597 | Train time : 194.51 s | Lr : 0.00193445
{'loss': 9.55966349493099, 'accuracy': 0.615, 'bleu_1': 0.11275652173913331, 'bleu_2': 0.06486702405554273, 'bleu_3': 0.015897333440537998, 'bleu_4': 0.006150794050808473}


 32%|███▏      | 16/50 [1:02:56<2:14:07, 236.68s/it]

Train loss : 2.9004 | Test loss : 9.5888 | Train time : 196.63 s | Lr : 0.00183773
{'loss': 9.588816437540174, 'accuracy': 0.615, 'bleu_1': 0.11227391304348064, 'bleu_2': 0.06498622413315439, 'bleu_3': 0.016589368299846967, 'bleu_4': 0.006939448765514963}


 34%|███▍      | 17/50 [1:06:55<2:10:35, 237.45s/it]

Train loss : 2.8687 | Test loss : 9.7246 | Train time : 194.46 s | Lr : 0.00174584
{'loss': 9.724602494058729, 'accuracy': 0.6152, 'bleu_1': 0.11130869565217671, 'bleu_2': 0.06427034993580454, 'bleu_3': 0.0156134739488669, 'bleu_4': 0.0063131665522568586}


 36%|███▌      | 18/50 [1:10:53<2:06:44, 237.65s/it]

Train loss : 2.8389 | Test loss : 9.8745 | Train time : 201.75 s | Lr : 0.00165855
{'loss': 9.874457117877428, 'accuracy': 0.615, 'bleu_1': 0.1101826086956546, 'bleu_2': 0.06322214394053767, 'bleu_3': 0.01582367344203886, 'bleu_4': 0.006320100924358943}


 38%|███▊      | 19/50 [1:14:59<2:04:00, 240.02s/it]

Train loss : 2.8133 | Test loss : 9.8694 | Train time : 200.49 s | Lr : 0.00157562
{'loss': 9.86940793146061, 'accuracy': 0.615, 'bleu_1': 0.11421739130435028, 'bleu_2': 0.0651545224999593, 'bleu_3': 0.015800650647151963, 'bleu_4': 0.006059771638494427}


 40%|████      | 20/50 [1:19:04<2:00:45, 241.50s/it]

Train loss : 2.7887 | Test loss : 9.9148 | Train time : 202.87 s | Lr : 0.00149684
{'loss': 9.914836086804353, 'accuracy': 0.615, 'bleu_1': 0.11342608695652412, 'bleu_2': 0.06536577829384056, 'bleu_3': 0.01605963745182739, 'bleu_4': 0.006635664045891721}


 42%|████▏     | 21/50 [1:23:33<2:00:44, 249.81s/it]

Train loss : 2.7664 | Test loss : 10.1050 | Train time : 273.61 s | Lr : 0.00142200
{'loss': 10.105016370362874, 'accuracy': 0.615, 'bleu_1': 0.11133913043478476, 'bleu_2': 0.06424519709067084, 'bleu_3': 0.016278943110170512, 'bleu_4': 0.006863290573398875}


 44%|████▍     | 22/50 [1:29:14<2:09:18, 277.09s/it]

Train loss : 2.7437 | Test loss : 10.1416 | Train time : 243.93 s | Lr : 0.00135090
{'loss': 10.141595417940163, 'accuracy': 0.615, 'bleu_1': 0.11190869565217679, 'bleu_2': 0.0646856717696088, 'bleu_3': 0.01597238579463997, 'bleu_4': 0.0063367949543177106}


 46%|████▌     | 23/50 [1:34:03<2:06:20, 280.77s/it]

Train loss : 2.7245 | Test loss : 10.2170 | Train time : 278.20 s | Lr : 0.00128336
{'loss': 10.217025467112094, 'accuracy': 0.615, 'bleu_1': 0.11271739130435074, 'bleu_2': 0.06501646443205038, 'bleu_3': 0.016847461472185677, 'bleu_4': 0.007154241977977715}


 48%|████▊     | 24/50 [1:39:52<2:10:34, 301.33s/it]

Train loss : 2.7040 | Test loss : 10.3225 | Train time : 279.95 s | Lr : 0.00121919
{'loss': 10.322463615031182, 'accuracy': 0.615, 'bleu_1': 0.11355217391304581, 'bleu_2': 0.06580413615210973, 'bleu_3': 0.015761661488887014, 'bleu_4': 0.006346785345838383}


 50%|█████     | 25/50 [1:45:41<2:11:23, 315.35s/it]

Train loss : 2.6854 | Test loss : 10.3049 | Train time : 228.70 s | Lr : 0.00115823
{'loss': 10.304898612106903, 'accuracy': 0.615, 'bleu_1': 0.11389130434782878, 'bleu_2': 0.06578393463851202, 'bleu_3': 0.01609225522783975, 'bleu_4': 0.006464701264231019}


 52%|█████▏    | 26/50 [1:50:13<2:00:56, 302.35s/it]

Train loss : 2.6680 | Test loss : 10.3951 | Train time : 202.45 s | Lr : 0.00110032
{'loss': 10.395073021514506, 'accuracy': 0.615, 'bleu_1': 0.11293913043478529, 'bleu_2': 0.0655056357184131, 'bleu_3': 0.01678869694964415, 'bleu_4': 0.006620419448199945}


 54%|█████▍    | 27/50 [1:54:18<1:49:19, 285.21s/it]

Train loss : 2.6502 | Test loss : 10.4826 | Train time : 199.97 s | Lr : 0.00104530
{'loss': 10.482555956780155, 'accuracy': 0.6152, 'bleu_1': 0.11380434782608961, 'bleu_2': 0.06571490005619059, 'bleu_3': 0.016987616004978025, 'bleu_4': 0.006629223840833713}


 56%|█████▌    | 28/50 [1:58:21<1:40:00, 272.74s/it]

Train loss : 2.6356 | Test loss : 10.5378 | Train time : 190.72 s | Lr : 0.00099304
{'loss': 10.537840420686745, 'accuracy': 0.6152, 'bleu_1': 0.11331304347826353, 'bleu_2': 0.0656022625455959, 'bleu_3': 0.01631950690695666, 'bleu_4': 0.006451429075378772}


 58%|█████▊    | 29/50 [2:02:15<1:31:19, 260.91s/it]

Train loss : 2.6202 | Test loss : 10.5616 | Train time : 186.55 s | Lr : 0.00094338
{'loss': 10.56162091750133, 'accuracy': 0.615, 'bleu_1': 0.11330434782608963, 'bleu_2': 0.06565886365000664, 'bleu_3': 0.01692638481465299, 'bleu_4': 0.0066546058024266885}


 60%|██████    | 30/50 [2:06:06<1:24:02, 252.14s/it]

Train loss : 2.6073 | Test loss : 10.6789 | Train time : 200.75 s | Lr : 0.00089621
{'loss': 10.678919743887986, 'accuracy': 0.615, 'bleu_1': 0.11413043478261135, 'bleu_2': 0.06632167004043808, 'bleu_3': 0.017684434493762304, 'bleu_4': 0.006943756744529968}


 62%|██████▏   | 31/50 [2:10:17<1:19:43, 251.74s/it]

Train loss : 2.5931 | Test loss : 10.7995 | Train time : 224.18 s | Lr : 0.00085140
{'loss': 10.799514649789545, 'accuracy': 0.6152, 'bleu_1': 0.1133956521739155, 'bleu_2': 0.0657120472248753, 'bleu_3': 0.01744929629863652, 'bleu_4': 0.00709100510520327}


 64%|██████▍   | 32/50 [2:14:50<1:17:25, 258.11s/it]

Train loss : 2.5810 | Test loss : 10.7619 | Train time : 216.66 s | Lr : 0.00080883
{'loss': 10.761868138856526, 'accuracy': 0.6152, 'bleu_1': 0.11382608695652416, 'bleu_2': 0.06566045925640542, 'bleu_3': 0.017183361210495386, 'bleu_4': 0.006903037306728821}


 66%|██████▌   | 33/50 [2:19:13<1:13:31, 259.50s/it]

Train loss : 2.5691 | Test loss : 10.8053 | Train time : 214.82 s | Lr : 0.00076839
{'loss': 10.80529694617549, 'accuracy': 0.6152, 'bleu_1': 0.11361304347826347, 'bleu_2': 0.06572098979101501, 'bleu_3': 0.01690921412447092, 'bleu_4': 0.006782871998994756}


 68%|██████▊   | 34/50 [2:23:37<1:09:33, 260.84s/it]

Train loss : 2.5573 | Test loss : 10.8365 | Train time : 206.41 s | Lr : 0.00072997
{'loss': 10.83652013464819, 'accuracy': 0.6152, 'bleu_1': 0.1158739130434809, 'bleu_2': 0.06710192418706255, 'bleu_3': 0.017531473613070312, 'bleu_4': 0.007025659755293735}


 70%|███████   | 35/50 [2:27:50<1:04:35, 258.39s/it]

Train loss : 2.5454 | Test loss : 10.8984 | Train time : 213.39 s | Lr : 0.00069347
{'loss': 10.898436872265007, 'accuracy': 0.6152, 'bleu_1': 0.11412173913043751, 'bleu_2': 0.06591120820870335, 'bleu_3': 0.01727510666317028, 'bleu_4': 0.007164424279645997}


 72%|███████▏  | 36/50 [2:32:13<1:00:37, 259.83s/it]

Train loss : 2.5361 | Test loss : 10.9602 | Train time : 219.68 s | Lr : 0.00065880
{'loss': 10.960202917268004, 'accuracy': 0.6152, 'bleu_1': 0.11426521739130714, 'bleu_2': 0.06618928674145291, 'bleu_3': 0.016999952198874967, 'bleu_4': 0.006750141130317421}


 74%|███████▍  | 37/50 [2:36:42<56:55, 262.72s/it]  

Train loss : 2.5243 | Test loss : 10.8977 | Train time : 209.10 s | Lr : 0.00062586
{'loss': 10.897735221476495, 'accuracy': 0.615, 'bleu_1': 0.11418695652174209, 'bleu_2': 0.06600906464816252, 'bleu_3': 0.016553410229496416, 'bleu_4': 0.006653074071172158}


 76%|███████▌  | 38/50 [2:40:57<52:03, 260.30s/it]

Train loss : 2.5170 | Test loss : 11.1414 | Train time : 217.76 s | Lr : 0.00059457
{'loss': 11.141402316998832, 'accuracy': 0.6152, 'bleu_1': 0.11406521739130687, 'bleu_2': 0.06593207381164472, 'bleu_3': 0.017035992941278468, 'bleu_4': 0.006987865799700198}


 78%|███████▊  | 39/50 [2:45:25<48:07, 262.51s/it]

Train loss : 2.5089 | Test loss : 11.0614 | Train time : 215.74 s | Lr : 0.00056484
{'loss': 11.061372889748103, 'accuracy': 0.6152, 'bleu_1': 0.11397391304348108, 'bleu_2': 0.06620202851373166, 'bleu_3': 0.017175629720091855, 'bleu_4': 0.006753703847890171}


 80%|████████  | 40/50 [2:49:47<43:43, 262.37s/it]

Train loss : 2.4975 | Test loss : 11.1530 | Train time : 208.25 s | Lr : 0.00053660
{'loss': 11.153032145922698, 'accuracy': 0.6152, 'bleu_1': 0.11307826086956808, 'bleu_2': 0.06525705639192667, 'bleu_3': 0.016531951682852132, 'bleu_4': 0.006678297629635735}


 82%|████████▏ | 41/50 [2:54:07<39:15, 261.72s/it]

Train loss : 2.4891 | Test loss : 11.1107 | Train time : 216.07 s | Lr : 0.00050977
{'loss': 11.110674664944034, 'accuracy': 0.6152, 'bleu_1': 0.1156217391304372, 'bleu_2': 0.06695646395523594, 'bleu_3': 0.017460988118229222, 'bleu_4': 0.007207228194910621}


 84%|████████▍ | 42/50 [2:58:30<34:57, 262.16s/it]

Train loss : 2.4833 | Test loss : 11.2142 | Train time : 210.95 s | Lr : 0.00048428
{'loss': 11.214244576949108, 'accuracy': 0.6152, 'bleu_1': 0.1143913043478288, 'bleu_2': 0.06646912177734571, 'bleu_3': 0.017867940702911946, 'bleu_4': 0.007170603224365993}


 86%|████████▌ | 43/50 [3:02:47<30:25, 260.76s/it]

Train loss : 2.4806 | Test loss : 11.2120 | Train time : 218.21 s | Lr : 0.00046006
{'loss': 11.212002271338354, 'accuracy': 0.6152, 'bleu_1': 0.11400000000000289, 'bleu_2': 0.0661926059552582, 'bleu_3': 0.017511618147793347, 'bleu_4': 0.006872162192866159}


 88%|████████▊ | 44/50 [3:07:16<26:17, 262.95s/it]

Train loss : 2.4705 | Test loss : 11.2596 | Train time : 217.41 s | Lr : 0.00043706
{'loss': 11.2596305774737, 'accuracy': 0.615, 'bleu_1': 0.1135130434782635, 'bleu_2': 0.06584752216044179, 'bleu_3': 0.017181350312123762, 'bleu_4': 0.00702866405729592}


 90%|█████████ | 45/50 [3:11:44<22:02, 264.60s/it]

Train loss : 2.4636 | Test loss : 11.1777 | Train time : 213.02 s | Lr : 0.00041521
{'loss': 11.177659722823131, 'accuracy': 0.6152, 'bleu_1': 0.11460434782608978, 'bleu_2': 0.06641167671578561, 'bleu_3': 0.017162078059098612, 'bleu_4': 0.006701933636893738}


 92%|█████████▏| 46/50 [3:16:04<17:33, 263.36s/it]

Train loss : 2.4566 | Test loss : 11.3075 | Train time : 209.69 s | Lr : 0.00039445
{'loss': 11.307484892350208, 'accuracy': 0.6152, 'bleu_1': 0.11570000000000281, 'bleu_2': 0.06730082700349842, 'bleu_3': 0.018131077704571858, 'bleu_4': 0.007354531982573151}


 94%|█████████▍| 47/50 [3:20:24<13:06, 262.27s/it]

Train loss : 2.4509 | Test loss : 11.3232 | Train time : 211.86 s | Lr : 0.00037473
{'loss': 11.323190085495575, 'accuracy': 0.6152, 'bleu_1': 0.11563478260869851, 'bleu_2': 0.06708574537050492, 'bleu_3': 0.017112695450623213, 'bleu_4': 0.00699461667340886}


 96%|█████████▌| 48/50 [3:24:47<08:44, 262.35s/it]

Train loss : 2.4452 | Test loss : 11.3913 | Train time : 225.83 s | Lr : 0.00035599
{'loss': 11.391299296029006, 'accuracy': 0.6152, 'bleu_1': 0.11505217391304588, 'bleu_2': 0.06702450582824769, 'bleu_3': 0.01788659174830984, 'bleu_4': 0.007419327717471376}


100%|██████████| 50/50 [3:33:48<00:00, 256.56s/it]

Train loss : 2.4400 | Test loss : 11.3976 | Train time : 219.52 s | Lr : 0.00033819
{'loss': 11.39760019809385, 'accuracy': 0.615, 'bleu_1': 0.11586956521739385, 'bleu_2': 0.06725602007077935, 'bleu_3': 0.01771069264615197, 'bleu_4': 0.007425141424463584}





Complete
