Format lại code

In [1]:
import matplotlib.pyplot as plt
import os
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, TensorDataset
import torch.nn.functional as F
import torchvision.models as models
from torchvision import transforms
import math
import torch.multiprocessing as mp
from transformers import BertTokenizerFast
import pickle
# Evaluate
from nltk.translate.bleu_score import sentence_bleu
# from nltk.translate.meteor_score import meteor_score
# from rogue import Rogue
# from pycocoevalcap.cider.cider import Cider
# from pycocoevalcap.spice.spice import Spice
#end
# mp.set_start_method('spawn', force=True)
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
detokenize = tokenizer.convert_ids_to_tokens
batch_detokenize = tokenizer.batch_decode
BATCH_SIZE = 64 
#64 : 3.1GB VRAM b16 (3.0 Dedicated | 0.1 Shared)
device = 'cuda'
image_transforms = transforms.Compose([
    transforms.Resize((299, 299)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Dùng cho ImageNet
])


In [2]:
def load_image(image_folder_path: str): # DataLoader trên Jupyternotebook ko xử lý đa luồng đc nên đành phải load hết vô
    result = {}
    file_counts = len(os.listdir(image_folder_path))
    progress = 0
    last_log = 0
    count = 0
    for file_name in os.listdir(image_folder_path):
        file_path = os.path.join(image_folder_path, file_name)
        if os.path.isfile(file_path):
            id = os.path.splitext(os.path.basename(file_name))[0]
            if id in result: continue
            with Image.open(file_path).convert("RGB") as img:
                image = image_transforms(img)
                result[id] = torch.Tensor(image)
            count += 1
            progress = count / file_counts * 100
            if progress - last_log >= 10:
                last_log = progress
                print(f"Image loading {progress:.2f} % ")
    if last_log != 100:
        print(f"Image loading 100 % ")
    return result
def process_data(image_dict: dict[str, torch.Tensor], processed_data_path: str):
    result: list[tuple[torch.Tensor, torch.Tensor]] = []
    with open(processed_data_path, 'rb') as file:
        processed_data = pickle.load(file)
    for image_id, caption in processed_data:
        caption = torch.tensor(caption)
        image = image_dict[image_id]
        result.append((image, caption))
    return result
def get_train_test_loader(batch_size: int, n_wokers: int = 2):
    image_path = "../Flickr8k/Flicker8k_Dataset"
    train_path = "../Data_bert/train_set_bert.pkl"
    test_path = "../Data_bert/test_set_bert.pkl"
    image_dict: dict[str, torch.Tensor] = load_image(image_path)
    train_data = process_data(image_dict, train_path)
    test_data = process_data(image_dict, test_path)
    trainloader = DataLoader(train_data, batch_size=batch_size, num_workers=n_wokers, shuffle=True)
    testloader = DataLoader(test_data, batch_size=batch_size, num_workers=n_wokers, shuffle=False)
    return trainloader, testloader
trainloader, testloader = get_train_test_loader(BATCH_SIZE, 2)

Image loading 10.01 % 
Image loading 20.02 % 
Image loading 30.03 % 
Image loading 40.04 % 
Image loading 50.06 % 
Image loading 60.07 % 
Image loading 70.08 % 
Image loading 80.09 % 
Image loading 90.10 % 
Image loading 100 % 


In [3]:
# for images, captions in trainloader:
#     for i in range(min(4, captions.shape[0])):
#         print(images[i][:,100,100])
#         print(detokenize(captions[i]))
#     break

In [4]:
class EncoderCNN(nn.Module):
    def __init__(self, output_size: int):
        super(EncoderCNN, self).__init__()
        self.inception_model = models.inception_v3(pretrained=True)
        #self.inception_model.fc = torch.nn.Identity()
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.2)
        self.fc = nn.Linear(1000, output_size)
        for name, param in self.inception_model.named_parameters():
            if "fc.weight" in name or "fc.bias" in name:
                param.requires_grad = True
            else:
                param.requires_grad = False
    def forward(self, images: torch.Tensor):
        features = self.inception_model(images) #[1, 2048]
        if isinstance(features, tuple):  # Nếu là tuple
            features = features[0] 
        features = self.relu(features)
        features = self.dropout(features)
        features = self.fc(features)
        return features
class DecoderRNN(nn.Module):
    def __init__(self, embed_size, vocab_size, hidden_size, input_size, num_layers):
        super(DecoderRNN, self).__init__()
        self.vocab_size = vocab_size
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size + input_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.use_hidden = False
    def forward(self, features: torch.Tensor, captions: torch.Tensor, hidden_state: tuple[torch.Tensor, torch.Tensor] = None): 
        # Constant : seq = 1
        # features : [bsz, img_sz]
        # captions : [bsz, seq]
        # hidden : [num_layers, bsz, hidden]
        embeddings = self.embed(captions) # [bsz, seq, embed]
        features = features.unsqueeze(1).expand(-1, embeddings.shape[1], -1) # [bsz, seq, embed]
        combined = torch.cat((features, embeddings), dim=2) # [bsz, seq, img_sz + embed]
        # hidden_state : [num_layers, seq, hid] * 2
        if self.use_hidden:
            output, hidden_state = self.lstm(combined, hidden_state)
        else:
            output, hidden_state = self.lstm(combined)
        # output : [bsz, seq, vocab_size]
        output = self.relu(output)
        output = self.dropout(output)
        output = self.linear(output)
        return output, hidden_state
class ImageToTextModel(nn.Module):
    def __init__(self, encoder: nn.Module, decoder: nn.Module):
        super(ImageToTextModel, self).__init__()
        self.encoder: EncoderCNN = encoder
        self.decoder: DecoderRNN = decoder
    def forward(self, images: torch.Tensor, captions: torch.Tensor):
        # Constant : SEQ_LENGTH = 46, seq = 1
        # images: [bsz, 3, raw_image_width, raw_image_height]
        # captions: [bsz, SEQ_LENGTH]
        bsz = images.shape[0]
        if self.decoder.use_hidden:
            hidden_state: tuple[torch.Tensor, torch.Tensor] = None
        features = self.encoder(images)
        # features: [bsz, img_sz]
        seq_predicted = []
        seq_predicted.append(torch.zeros((bsz, self.decoder.vocab_size), dtype=torch.float32).unsqueeze(1).to(device))
        # seq_predicted : [predict_length, seq, vocab]
        decoder_input = captions[:, 0].unsqueeze(1)
        # decoder_input : [bsz, seq]
        seq_length = captions.shape[1]
        for di in range(1, seq_length):
            if self.decoder.use_hidden:
                output_decoder, hidden_state = self.decoder(features, decoder_input, hidden_state)
            else:
                output_decoder, hidden_state = self.decoder(features, decoder_input)
            # ouput_decoder: [bsz, seq, vocab]
            # hidden_state: [num_layers, bsz, hidden_size] * 2
            decoder_input = captions[:, di].unsqueeze(1)
            # decoder_input : [bsz, seq]
            seq_predicted.append(output_decoder)
        return torch.cat(seq_predicted, dim=1)
    def predict(self, images: torch.Tensor, captions: torch.Tensor, predict_length: int):
        bsz = images.shape[0]
        hidden_state: tuple[torch.Tensor, torch.Tensor] = None
        features = self.encoder(images)
        seq_predicted = []
        seq_predicted.append(torch.zeros((bsz, self.decoder.vocab_size), dtype=torch.float32).unsqueeze(1).to(device))
        decoder_input = captions
        for _ in range(predict_length-1):
            output_decoder, hidden_state = self.decoder(features, decoder_input, hidden_state)
            seq_predicted.append(output_decoder)
            decoder_input = output_decoder.argmax(2)
        return torch.cat(seq_predicted, dim=1)


In [5]:
def train(model : ImageToTextModel, dataloader : DataLoader, lossf : callable, optimizer : torch.optim.Optimizer, mixed: bool, device: str):
    model.train()
    epoch_loss = 0
    count = 0
    for images, captions in dataloader:
        optimizer.zero_grad()
        images: torch.Tensor = images.to(device)
        captions: torch.Tensor = captions.to(device)
        if mixed:
            with torch.autocast(device_type=device, dtype=torch.bfloat16):
                outputs: torch.Tensor = model(images, captions)
                loss: torch.Tensor = lossf(outputs.view(-1, outputs.shape[2]), captions.view(-1))
        else:
            outputs: torch.Tensor = model(images, captions)
            loss: torch.Tensor = lossf(outputs.view(-1, outputs.shape[2]), captions.view(-1))
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        count += 1 
    return epoch_loss / count
def compute_metrics(predict: torch.Tensor, caption: torch.Tensor): # Quên chưa xóa token đầu ...
    # Accuracy
    is_correct = 1 if predict[1] == caption[1] else 0
    caption: list[str] = detokenize(caption)
    predict: list[str] = detokenize(predict)
    while (caption[-1] == "[PAD]"):
        caption.pop()
    while (predict[-1] == "[PAD]"):
        predict.pop()
    # Bleu
    reference = [caption]
    # print(caption, predict)
    bleu_1 = sentence_bleu(reference, predict, weights=(1, 0, 0, 0))
    bleu_2 = sentence_bleu(reference, predict, weights=(0.5, 0.5, 0, 0))
    bleu_3 = sentence_bleu(reference, predict, weights=(0.33, 0.33, 0.33, 0))
    bleu_4 = sentence_bleu(reference, predict)
    # Rogue
    # rogue = Rogue()
    # scores = rogue.get_scores(' '.join(predicts), ' '.join(captions))
    # Meteor
    # meteor_score_ = meteor_score([' '.join(captions)], ' '.join(predicts))
    # Cider, spice
    # cider_scorer = Cider()
    # spice_scorer = Spice()
    # cider_score, _ = cider_scorer.compute_score({0 : [' '.join(caption)]}, {0 : [' '.join(predict)]})
    # spice_score, _ = spice_scorer.compute_score({0 : [' '.join(captions)]}, {0 : [' '.join(predicts)]})
    return {
        "accuracy" : is_correct,
        "bleu_1" : bleu_1,
        "bleu_2" : bleu_2,
        "bleu_3" : bleu_3,
        "bleu_4" : bleu_4,
        # "rogue-1" : scores["rogue-1"]["f"],
        # "rogue-2" : scores["rogue-2"]["f"],
        # "rogue-l" : scores["rogue-l"]["f"],
        # "meteor" : meteor_score_,
        # "cider" : cider_score,
        # "spice" : spice_score
    }
def test(model : ImageToTextModel, dataloader : DataLoader, lossf : callable, mixed: bool, device: str):
    model.eval()
    count = 0
    total_count = 0
    epoch_metrics = {'loss' : 0}
    input = torch.tensor([101]).to(device)
    for images, captions in dataloader:
        images: torch.Tensor = images.to(device)
        captions: torch.Tensor = captions.to(device)
        bsz = images.shape[0]
        inputs = input.unsqueeze(1).expand((bsz, 1))
        if mixed:
            with torch.autocast(device_type=device, dtype=torch.bfloat16):
                outputs: torch.Tensor = model.predict(images, inputs, captions.shape[1])
                loss: torch.Tensor= lossf(outputs.view(-1, outputs.shape[2]), captions.view(-1))
        else:
            outputs: torch.Tensor  = model.predict(images, inputs, captions.shape[1])
            loss: torch.Tensor= lossf(outputs.view(-1, outputs.shape[2]), captions.view(-1))
        predicts = outputs.argmax(2)
        count += 1
        total_count += bsz
        epoch_metrics['loss'] += loss.item()
        for i in range(bsz):
            metrics = compute_metrics(predicts[i], captions[i])
            for key in metrics:
                if key not in epoch_metrics:
                    epoch_metrics[key] = metrics[key]
                else:
                    epoch_metrics[key] += metrics[key]
    for key in epoch_metrics:
        if key in ['loss']:
            epoch_metrics[key] /= count
        else:
            epoch_metrics[key] /= total_count
    return epoch_metrics

In [6]:
import util
image_size = 256
embed_size = 256
hidden_size = 256
encoder = EncoderCNN(
    output_size=image_size
)
decoder = DecoderRNN(
    embed_size=embed_size,
    vocab_size=tokenizer.vocab_size,
    hidden_size=hidden_size,
    input_size=image_size,
    num_layers=1
)
decoder.use_hidden = False
image_to_text_model = ImageToTextModel(
    encoder=encoder,
    decoder=decoder
)
loss_func= nn.CrossEntropyLoss(ignore_index=0)
model_path, train_log = util.train_eval(
    trainloader=trainloader,
    testloader=testloader,
    model=image_to_text_model,
    train_func=train,
    test_func=test,
    lossf=loss_func,
    num_epochs=50,
    lr=2.5e-3,
    gamma=0.95,
    log_step=1,
    warmup_nepochs=10,
    warmup_lr=1e-3,
    warmup_gamma=1.1,
    save=True,
    mixed_train = True,
    mixed_eval = False,
    save_path="checkpoint/teset_model",
    optimizer_type=torch.optim.Adam,
    device=device,
    metadata_extra={
        "batch_size" : BATCH_SIZE,
        "dataset_name" : "Flickr8k",
        "use_hidden" : decoder.use_hidden
    },
    log_metric=True
)



Start train


The hypothesis contains 0 counts of 2-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
  2%|▏         | 1/50 [04:28<3:39:26, 268.71s/it]

Train loss : 5.9695 | Test loss : 5.7548 | Train time : 222.87 s | Lr : 0.00100000
{'loss': 5.754833100717279, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


  4%|▍         | 2/50 [08:19<3:17:08, 246.43s/it]

Train loss : 5.7842 | Test loss : 5.7598 | Train time : 190.18 s | Lr : 0.00110000
{'loss': 5.759811159930652, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


  6%|▌         | 3/50 [12:15<3:09:21, 241.73s/it]

Train loss : 5.7796 | Test loss : 5.7606 | Train time : 195.40 s | Lr : 0.00121000
{'loss': 5.760557168646704, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


  8%|▊         | 4/50 [16:21<3:06:42, 243.54s/it]

Train loss : 5.7815 | Test loss : 5.7712 | Train time : 206.02 s | Lr : 0.00133100
{'loss': 5.771235713475867, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 10%|█         | 5/50 [20:06<2:57:38, 236.85s/it]

Train loss : 5.7866 | Test loss : 5.7759 | Train time : 184.69 s | Lr : 0.00146410
{'loss': 5.775890452952325, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 12%|█▏        | 6/50 [23:51<2:50:44, 232.82s/it]

Train loss : 5.7924 | Test loss : 5.7820 | Train time : 184.53 s | Lr : 0.00161051
{'loss': 5.781957149505615, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 14%|█▍        | 7/50 [27:37<2:45:01, 230.27s/it]

Train loss : 5.8000 | Test loss : 5.7923 | Train time : 184.85 s | Lr : 0.00177156
{'loss': 5.792266869846778, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 16%|█▌        | 8/50 [31:35<2:43:00, 232.87s/it]

Train loss : 5.8101 | Test loss : 5.7932 | Train time : 198.15 s | Lr : 0.00194872
{'loss': 5.793198700192608, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 18%|█▊        | 9/50 [35:33<2:40:10, 234.40s/it]

Train loss : 5.8253 | Test loss : 5.8051 | Train time : 197.34 s | Lr : 0.00214359
{'loss': 5.805065269711651, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 20%|██        | 10/50 [39:29<2:36:45, 235.14s/it]

Train loss : 5.8331 | Test loss : 5.8110 | Train time : 196.12 s | Lr : 0.00235795
{'loss': 5.8110112902484365, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 22%|██▏       | 11/50 [43:15<2:30:59, 232.30s/it]

Train loss : 5.8384 | Test loss : 5.8177 | Train time : 185.37 s | Lr : 0.00250000
{'loss': 5.817700687843033, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 24%|██▍       | 12/50 [47:06<2:26:53, 231.93s/it]

Train loss : 5.8352 | Test loss : 5.8172 | Train time : 190.34 s | Lr : 0.00237500
{'loss': 5.817217102533655, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 26%|██▌       | 13/50 [51:03<2:23:53, 233.34s/it]

Train loss : 5.8314 | Test loss : 5.8146 | Train time : 196.49 s | Lr : 0.00225625
{'loss': 5.81463390060618, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 28%|██▊       | 14/50 [54:58<2:20:14, 233.72s/it]

Train loss : 5.8272 | Test loss : 5.8201 | Train time : 194.60 s | Lr : 0.00214344
{'loss': 5.8201395288298405, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 30%|███       | 15/50 [58:42<2:14:46, 231.04s/it]

Train loss : 5.8224 | Test loss : 5.8105 | Train time : 184.68 s | Lr : 0.00203627
{'loss': 5.810458346258236, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 32%|███▏      | 16/50 [1:02:39<2:11:51, 232.69s/it]

Train loss : 5.8172 | Test loss : 5.8111 | Train time : 196.36 s | Lr : 0.00193445
{'loss': 5.81110730352281, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 34%|███▍      | 17/50 [1:06:32<2:08:04, 232.86s/it]

Train loss : 5.8138 | Test loss : 5.8118 | Train time : 193.09 s | Lr : 0.00183773
{'loss': 5.81178516074072, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 36%|███▌      | 18/50 [1:10:30<2:04:58, 234.34s/it]

Train loss : 5.8088 | Test loss : 5.8125 | Train time : 197.24 s | Lr : 0.00174584
{'loss': 5.812481928475296, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 38%|███▊      | 19/50 [1:14:16<1:59:42, 231.70s/it]

Train loss : 5.8059 | Test loss : 5.8074 | Train time : 184.94 s | Lr : 0.00165855
{'loss': 5.807435578937772, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 40%|████      | 20/50 [1:18:06<1:55:38, 231.28s/it]

Train loss : 5.8010 | Test loss : 5.8085 | Train time : 190.12 s | Lr : 0.00157562
{'loss': 5.808524777617635, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 42%|████▏     | 21/50 [1:21:59<1:51:59, 231.71s/it]

Train loss : 5.7972 | Test loss : 5.8097 | Train time : 192.35 s | Lr : 0.00149684
{'loss': 5.809740483006345, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 44%|████▍     | 22/50 [1:25:53<1:48:30, 232.51s/it]

Train loss : 5.7944 | Test loss : 5.8045 | Train time : 193.60 s | Lr : 0.00142200
{'loss': 5.8044575196278245, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 46%|████▌     | 23/50 [1:29:39<1:43:43, 230.51s/it]

Train loss : 5.7906 | Test loss : 5.8075 | Train time : 185.03 s | Lr : 0.00135090
{'loss': 5.8075227495990225, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 48%|████▊     | 24/50 [1:33:36<1:40:46, 232.55s/it]

Train loss : 5.7877 | Test loss : 5.8057 | Train time : 196.86 s | Lr : 0.00128336
{'loss': 5.805685858183269, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 50%|█████     | 25/50 [1:37:31<1:37:07, 233.10s/it]

Train loss : 5.7837 | Test loss : 5.8032 | Train time : 194.37 s | Lr : 0.00121919
{'loss': 5.80316784412046, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 52%|█████▏    | 26/50 [1:41:27<1:33:38, 234.09s/it]

Train loss : 5.7808 | Test loss : 5.8015 | Train time : 195.67 s | Lr : 0.00115823
{'loss': 5.80154206481161, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 54%|█████▍    | 27/50 [1:45:19<1:29:28, 233.42s/it]

Train loss : 5.7783 | Test loss : 5.7985 | Train time : 191.41 s | Lr : 0.00110032
{'loss': 5.798532673075229, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 56%|█████▌    | 28/50 [1:49:11<1:25:24, 232.94s/it]

Train loss : 5.7756 | Test loss : 5.8018 | Train time : 191.78 s | Lr : 0.00104530
{'loss': 5.801845254777353, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 58%|█████▊    | 29/50 [1:52:57<1:20:48, 230.88s/it]

Train loss : 5.7719 | Test loss : 5.7999 | Train time : 185.84 s | Lr : 0.00099304
{'loss': 5.799902668482141, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 60%|██████    | 30/50 [1:56:47<1:16:55, 230.78s/it]

Train loss : 5.7703 | Test loss : 5.7980 | Train time : 190.47 s | Lr : 0.00094338
{'loss': 5.7979735845251925, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 62%|██████▏   | 31/50 [2:00:35<1:12:48, 229.92s/it]

Train loss : 5.7670 | Test loss : 5.7966 | Train time : 187.42 s | Lr : 0.00089621
{'loss': 5.796579095381725, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 64%|██████▍   | 32/50 [2:04:20<1:08:31, 228.41s/it]

Train loss : 5.7641 | Test loss : 5.7975 | Train time : 184.63 s | Lr : 0.00085140
{'loss': 5.7974796174447745, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 66%|██████▌   | 33/50 [2:08:16<1:05:22, 230.74s/it]

Train loss : 5.7630 | Test loss : 5.7969 | Train time : 196.06 s | Lr : 0.00080883
{'loss': 5.7969389867179, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 68%|██████▊   | 34/50 [2:12:09<1:01:40, 231.26s/it]

Train loss : 5.7608 | Test loss : 5.7944 | Train time : 192.22 s | Lr : 0.00076839
{'loss': 5.79443914075441, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 70%|███████   | 35/50 [2:16:05<58:09, 232.65s/it]  

Train loss : 5.7586 | Test loss : 5.7974 | Train time : 195.86 s | Lr : 0.00072997
{'loss': 5.797447959079018, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 72%|███████▏  | 36/50 [2:20:00<54:27, 233.42s/it]

Train loss : 5.7575 | Test loss : 5.7934 | Train time : 195.05 s | Lr : 0.00069347
{'loss': 5.793442557129679, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 74%|███████▍  | 37/50 [2:23:54<50:38, 233.70s/it]

Train loss : 5.7545 | Test loss : 5.7933 | Train time : 194.14 s | Lr : 0.00065880
{'loss': 5.793302403220648, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 76%|███████▌  | 38/50 [2:27:40<46:17, 231.47s/it]

Train loss : 5.7524 | Test loss : 5.7917 | Train time : 186.26 s | Lr : 0.00062586
{'loss': 5.791708529750003, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 78%|███████▊  | 39/50 [2:31:34<42:34, 232.19s/it]

Train loss : 5.7517 | Test loss : 5.7925 | Train time : 193.88 s | Lr : 0.00059457
{'loss': 5.792530470256564, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 80%|████████  | 40/50 [2:35:30<38:52, 233.26s/it]

Train loss : 5.7508 | Test loss : 5.7921 | Train time : 195.45 s | Lr : 0.00056484
{'loss': 5.792130029654201, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 82%|████████▏ | 41/50 [2:39:25<35:03, 233.77s/it]

Train loss : 5.7487 | Test loss : 5.7927 | Train time : 195.14 s | Lr : 0.00053660
{'loss': 5.792679901364483, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 84%|████████▍ | 42/50 [2:43:11<30:50, 231.37s/it]

Train loss : 5.7465 | Test loss : 5.7911 | Train time : 185.75 s | Lr : 0.00050977
{'loss': 5.7911017816278, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 86%|████████▌ | 43/50 [2:47:01<26:57, 231.01s/it]

Train loss : 5.7462 | Test loss : 5.7899 | Train time : 189.98 s | Lr : 0.00048428
{'loss': 5.7899137388301805, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 88%|████████▊ | 44/50 [2:50:58<23:16, 232.75s/it]

Train loss : 5.7449 | Test loss : 5.7891 | Train time : 196.42 s | Lr : 0.00046006
{'loss': 5.789105807678609, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 90%|█████████ | 45/50 [2:54:54<19:28, 233.70s/it]

Train loss : 5.7427 | Test loss : 5.7905 | Train time : 195.68 s | Lr : 0.00043706
{'loss': 5.790540665010862, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 92%|█████████▏| 46/50 [2:58:39<15:24, 231.14s/it]

Train loss : 5.7421 | Test loss : 5.7908 | Train time : 184.84 s | Lr : 0.00041521
{'loss': 5.790824081324324, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 94%|█████████▍| 47/50 [3:02:28<11:31, 230.62s/it]

Train loss : 5.7409 | Test loss : 5.7896 | Train time : 189.52 s | Lr : 0.00039445
{'loss': 5.789599436747877, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 96%|█████████▌| 48/50 [3:06:20<07:42, 231.08s/it]

Train loss : 5.7389 | Test loss : 5.7884 | Train time : 192.44 s | Lr : 0.00037473
{'loss': 5.788402792773669, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


 98%|█████████▊| 49/50 [3:10:17<03:52, 232.81s/it]

Train loss : 5.7379 | Test loss : 5.7874 | Train time : 197.03 s | Lr : 0.00035599
{'loss': 5.787425385245794, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}


100%|██████████| 50/50 [3:14:02<00:00, 232.84s/it]

Train loss : 5.7380 | Test loss : 5.7874 | Train time : 184.51 s | Lr : 0.00033819
{'loss': 5.787433714806279, 'accuracy': 0.615, 'bleu_1': 0.034013043478259536, 'bleu_2': 2.4360880535736356e-155, 'bleu_3': 2.5050322592968467e-204, 'bleu_4': 6.641550455034804e-232}





Complete


In [7]:
encoder = EncoderCNN(
    output_size=image_size
)
decoder = DecoderRNN(
    embed_size=embed_size,
    vocab_size=tokenizer.vocab_size,
    hidden_size=hidden_size,
    input_size=image_size,
    num_layers=1
)
decoder.use_hidden = True
image_to_text_model = ImageToTextModel(
    encoder=encoder,
    decoder=decoder
)
loss_func= nn.CrossEntropyLoss(ignore_index=0)
model_path, train_log = util.train_eval(
    trainloader=trainloader,
    testloader=testloader,
    model=image_to_text_model,
    train_func=train,
    test_func=test,
    lossf=loss_func,
    num_epochs=50,
    lr=2.5e-3,
    gamma=0.95,
    log_step=1,
    warmup_nepochs=10,
    warmup_lr=1e-3,
    warmup_gamma=1.1,
    save=True,
    mixed_train = True,
    mixed_eval = False,
    save_path="checkpoint/teset_model",
    optimizer_type=torch.optim.Adam,
    device=device,
    metadata_extra={
        "batch_size" : BATCH_SIZE,
        "dataset_name" : "Flickr8k",
        "use_hidden" : decoder.use_hidden
    },
    log_metric=True
)

Start train


  2%|▏         | 1/50 [04:00<3:16:08, 240.16s/it]

Train loss : 5.6786 | Test loss : 6.2473 | Train time : 199.30 s | Lr : 0.00100000
{'loss': 6.247293586972393, 'accuracy': 0.615, 'bleu_1': 0.10431304347826342, 'bleu_2': 0.05707525625403544, 'bleu_3': 0.004600745872437064, 'bleu_4': 0.001765013283302346}


  4%|▍         | 2/50 [07:50<3:07:32, 234.43s/it]

Train loss : 4.2373 | Test loss : 6.8711 | Train time : 189.01 s | Lr : 0.00110000
{'loss': 6.871086452580705, 'accuracy': 0.615, 'bleu_1': 0.10893913043478494, 'bleu_2': 0.06064154050944506, 'bleu_3': 0.009070156255103356, 'bleu_4': 0.003170201313093507}


  6%|▌         | 3/50 [11:34<2:59:56, 229.72s/it]

Train loss : 3.8972 | Test loss : 7.1387 | Train time : 182.52 s | Lr : 0.00121000
{'loss': 7.13874855524377, 'accuracy': 0.615, 'bleu_1': 0.10398260869565458, 'bleu_2': 0.05772242864502872, 'bleu_3': 0.009135295948403833, 'bleu_4': 0.003351903934542744}


  8%|▊         | 4/50 [15:18<2:54:18, 227.36s/it]

Train loss : 3.7031 | Test loss : 7.6400 | Train time : 182.49 s | Lr : 0.00133100
{'loss': 7.640007290659072, 'accuracy': 0.615, 'bleu_1': 0.10229130434782904, 'bleu_2': 0.05491666921398858, 'bleu_3': 0.010718897952299895, 'bleu_4': 0.004469809823080863}


 10%|█         | 5/50 [19:01<2:49:21, 225.80s/it]

Train loss : 3.5652 | Test loss : 7.6191 | Train time : 182.37 s | Lr : 0.00146410
{'loss': 7.619073572038095, 'accuracy': 0.615, 'bleu_1': 0.10543478260869829, 'bleu_2': 0.05940824096832533, 'bleu_3': 0.012783977989407362, 'bleu_4': 0.005074634137409146}


 12%|█▏        | 6/50 [22:45<2:45:04, 225.10s/it]

Train loss : 3.4642 | Test loss : 7.8342 | Train time : 182.52 s | Lr : 0.00161051
{'loss': 7.83419814894471, 'accuracy': 0.615, 'bleu_1': 0.10600869565217663, 'bleu_2': 0.05887736104755478, 'bleu_3': 0.012828657248513977, 'bleu_4': 0.00473046165405572}


 14%|█▍        | 7/50 [26:33<2:41:58, 226.01s/it]

Train loss : 3.3816 | Test loss : 8.0706 | Train time : 186.46 s | Lr : 0.00177156
{'loss': 8.070551244518425, 'accuracy': 0.615, 'bleu_1': 0.11126956521739398, 'bleu_2': 0.06334843252582066, 'bleu_3': 0.014599011437384662, 'bleu_4': 0.006103061099040205}


 16%|█▌        | 8/50 [30:16<2:37:41, 225.26s/it]

Train loss : 3.3200 | Test loss : 8.1742 | Train time : 182.59 s | Lr : 0.00194872
{'loss': 8.174180676665488, 'accuracy': 0.615, 'bleu_1': 0.09989565217391572, 'bleu_2': 0.05429650730551136, 'bleu_3': 0.011585048768724909, 'bleu_4': 0.004509160212234594}


 18%|█▊        | 9/50 [34:00<2:33:35, 224.76s/it]

Train loss : 3.2665 | Test loss : 8.4408 | Train time : 182.51 s | Lr : 0.00214359
{'loss': 8.440843437291399, 'accuracy': 0.615, 'bleu_1': 0.10930000000000248, 'bleu_2': 0.061781039402622334, 'bleu_3': 0.01264641509819853, 'bleu_4': 0.0047838974137251146}


 20%|██        | 10/50 [37:52<2:31:17, 226.95s/it]

Train loss : 3.2254 | Test loss : 8.5871 | Train time : 190.35 s | Lr : 0.00235795
{'loss': 8.587101103384283, 'accuracy': 0.615, 'bleu_1': 0.11034347826087207, 'bleu_2': 0.06307705141419687, 'bleu_3': 0.014589309944821363, 'bleu_4': 0.0063240327061607855}


 22%|██▏       | 11/50 [41:43<2:28:26, 228.36s/it]

Train loss : 3.1848 | Test loss : 8.7136 | Train time : 190.28 s | Lr : 0.00250000
{'loss': 8.71359702001644, 'accuracy': 0.615, 'bleu_1': 0.10840434782608957, 'bleu_2': 0.06179133215852094, 'bleu_3': 0.013503662877584673, 'bleu_4': 0.005809241097563724}


 24%|██▍       | 12/50 [45:27<2:23:39, 226.83s/it]

Train loss : 3.1347 | Test loss : 8.6922 | Train time : 182.33 s | Lr : 0.00237500
{'loss': 8.692210131053683, 'accuracy': 0.615, 'bleu_1': 0.11139130434782865, 'bleu_2': 0.06319330040021381, 'bleu_3': 0.01287907233943635, 'bleu_4': 0.005262314055011575}


 26%|██▌       | 13/50 [49:18<2:20:44, 228.23s/it]

Train loss : 3.0877 | Test loss : 8.9900 | Train time : 189.93 s | Lr : 0.00225625
{'loss': 8.989996270288396, 'accuracy': 0.615, 'bleu_1': 0.11081304347826317, 'bleu_2': 0.06323709812778351, 'bleu_3': 0.014084482471408524, 'bleu_4': 0.005692768410811251}


 28%|██▊       | 14/50 [53:08<2:17:13, 228.72s/it]

Train loss : 3.0473 | Test loss : 9.0298 | Train time : 188.70 s | Lr : 0.00214344
{'loss': 9.02979932857465, 'accuracy': 0.615, 'bleu_1': 0.11109565217391577, 'bleu_2': 0.06340608007000198, 'bleu_3': 0.013300297818389135, 'bleu_4': 0.005373937971261754}


 30%|███       | 15/50 [57:03<2:14:28, 230.52s/it]

Train loss : 3.0092 | Test loss : 9.1956 | Train time : 193.16 s | Lr : 0.00203627
{'loss': 9.195624194567717, 'accuracy': 0.615, 'bleu_1': 0.1142434782608722, 'bleu_2': 0.06523724639013033, 'bleu_3': 0.01443101414179508, 'bleu_4': 0.005038902412530196}


 32%|███▏      | 16/50 [1:01:02<2:12:09, 233.22s/it]

Train loss : 2.9741 | Test loss : 9.2074 | Train time : 198.57 s | Lr : 0.00193445
{'loss': 9.207419105722934, 'accuracy': 0.615, 'bleu_1': 0.1137478260869594, 'bleu_2': 0.06559477924522221, 'bleu_3': 0.014820965454642732, 'bleu_4': 0.006133015619561319}


 34%|███▍      | 17/50 [1:05:00<2:08:58, 234.50s/it]

Train loss : 2.9426 | Test loss : 9.3591 | Train time : 195.80 s | Lr : 0.00183773
{'loss': 9.359114514121526, 'accuracy': 0.615, 'bleu_1': 0.11290000000000254, 'bleu_2': 0.06498401809004548, 'bleu_3': 0.01513752503677603, 'bleu_4': 0.005897550415081777}


 36%|███▌      | 18/50 [1:08:50<2:04:24, 233.27s/it]

Train loss : 2.9112 | Test loss : 9.4657 | Train time : 189.42 s | Lr : 0.00174584
{'loss': 9.465674170964881, 'accuracy': 0.615, 'bleu_1': 0.11147826086956772, 'bleu_2': 0.06444228930868721, 'bleu_3': 0.01574289835928682, 'bleu_4': 0.0067110229629494525}


 38%|███▊      | 19/50 [1:12:42<2:00:21, 232.94s/it]

Train loss : 2.8840 | Test loss : 9.5161 | Train time : 191.22 s | Lr : 0.00165855
{'loss': 9.516107957574386, 'accuracy': 0.615, 'bleu_1': 0.11262173913043733, 'bleu_2': 0.06423310627314127, 'bleu_3': 0.013383715166367355, 'bleu_4': 0.005060601093286057}


 40%|████      | 20/50 [1:16:37<1:56:44, 233.49s/it]

Train loss : 2.8599 | Test loss : 9.6331 | Train time : 193.81 s | Lr : 0.00157562
{'loss': 9.633113185061685, 'accuracy': 0.615, 'bleu_1': 0.10369130434782874, 'bleu_2': 0.05801298637915888, 'bleu_3': 0.013808518427599267, 'bleu_4': 0.005486342740732564}


 42%|████▏     | 21/50 [1:20:30<1:52:48, 233.40s/it]

Train loss : 2.8327 | Test loss : 9.7534 | Train time : 192.23 s | Lr : 0.00149684
{'loss': 9.753424885906751, 'accuracy': 0.615, 'bleu_1': 0.11016086956521978, 'bleu_2': 0.06333559977802676, 'bleu_3': 0.014716457141516162, 'bleu_4': 0.00634714507106797}


 44%|████▍     | 22/50 [1:24:23<1:48:51, 233.26s/it]

Train loss : 2.8099 | Test loss : 9.8510 | Train time : 191.90 s | Lr : 0.00142200
{'loss': 9.850961950760853, 'accuracy': 0.615, 'bleu_1': 0.11173478260869833, 'bleu_2': 0.06448553325837475, 'bleu_3': 0.01593832141756544, 'bleu_4': 0.0066637238776096185}


 46%|████▌     | 23/50 [1:28:17<1:45:00, 233.36s/it]

Train loss : 2.7884 | Test loss : 9.8195 | Train time : 192.32 s | Lr : 0.00135090
{'loss': 9.819537343858164, 'accuracy': 0.615, 'bleu_1': 0.11262173913043758, 'bleu_2': 0.06477306493420479, 'bleu_3': 0.014722240503831049, 'bleu_4': 0.005919100142602602}


 48%|████▊     | 24/50 [1:32:08<1:40:52, 232.78s/it]

Train loss : 2.7670 | Test loss : 9.9454 | Train time : 190.24 s | Lr : 0.00128336
{'loss': 9.945360666588892, 'accuracy': 0.615, 'bleu_1': 0.1132478260869593, 'bleu_2': 0.06515064333959579, 'bleu_3': 0.015672252501494996, 'bleu_4': 0.0064018245249156365}


 50%|█████     | 25/50 [1:36:02<1:37:10, 233.21s/it]

Train loss : 2.7473 | Test loss : 10.0096 | Train time : 193.05 s | Lr : 0.00121919
{'loss': 10.009567345244974, 'accuracy': 0.615, 'bleu_1': 0.11468260869565497, 'bleu_2': 0.06615666401704577, 'bleu_3': 0.015723468604680164, 'bleu_4': 0.006575266418654408}


 52%|█████▏    | 26/50 [1:39:55<1:33:16, 233.17s/it]

Train loss : 2.7305 | Test loss : 10.0272 | Train time : 192.45 s | Lr : 0.00115823
{'loss': 10.027190594733517, 'accuracy': 0.615, 'bleu_1': 0.1138739130434809, 'bleu_2': 0.06586262988383867, 'bleu_3': 0.015886059672859318, 'bleu_4': 0.006383453490734003}


 54%|█████▍    | 27/50 [1:43:45<1:28:55, 231.96s/it]

Train loss : 2.7139 | Test loss : 10.1617 | Train time : 187.64 s | Lr : 0.00110032
{'loss': 10.1617013472545, 'accuracy': 0.615, 'bleu_1': 0.11295652173913293, 'bleu_2': 0.06483383557181947, 'bleu_3': 0.01656687300168692, 'bleu_4': 0.006797036299830124}


 56%|█████▌    | 28/50 [1:47:36<1:24:57, 231.71s/it]

Train loss : 2.6934 | Test loss : 10.1890 | Train time : 190.10 s | Lr : 0.00104530
{'loss': 10.189047246039669, 'accuracy': 0.615, 'bleu_1': 0.11206086956521988, 'bleu_2': 0.06470773458571288, 'bleu_3': 0.015851364828178827, 'bleu_4': 0.006420257280397095}


 58%|█████▊    | 29/50 [1:51:32<1:21:32, 232.96s/it]

Train loss : 2.6802 | Test loss : 10.2584 | Train time : 194.36 s | Lr : 0.00099304
{'loss': 10.258405914789513, 'accuracy': 0.615, 'bleu_1': 0.11211304347826358, 'bleu_2': 0.06455777645190432, 'bleu_3': 0.0150584952529573, 'bleu_4': 0.006045449830229635}


 60%|██████    | 30/50 [1:55:25<1:17:41, 233.10s/it]

Train loss : 2.6643 | Test loss : 10.2230 | Train time : 192.09 s | Lr : 0.00094338
{'loss': 10.223022279860098, 'accuracy': 0.615, 'bleu_1': 0.11502173913043759, 'bleu_2': 0.06657398352088804, 'bleu_3': 0.01641786817974751, 'bleu_4': 0.006701787299363676}


 62%|██████▏   | 31/50 [1:59:12<1:13:11, 231.14s/it]

Train loss : 2.6514 | Test loss : 10.4428 | Train time : 184.96 s | Lr : 0.00089621
{'loss': 10.442837220204028, 'accuracy': 0.615, 'bleu_1': 0.11539565217391559, 'bleu_2': 0.06606032379014927, 'bleu_3': 0.015506823621539897, 'bleu_4': 0.006598507264518569}


 64%|██████▍   | 32/50 [2:03:06<1:09:37, 232.06s/it]

Train loss : 2.6352 | Test loss : 10.3525 | Train time : 192.71 s | Lr : 0.00085140
{'loss': 10.352491294281393, 'accuracy': 0.615, 'bleu_1': 0.11366086956522015, 'bleu_2': 0.06553115632584322, 'bleu_3': 0.01649417033960021, 'bleu_4': 0.006690897462274008}


 66%|██████▌   | 33/50 [2:06:59<1:05:50, 232.40s/it]

Train loss : 2.6250 | Test loss : 10.3924 | Train time : 191.84 s | Lr : 0.00080883
{'loss': 10.392352345623548, 'accuracy': 0.615, 'bleu_1': 0.11366956521739408, 'bleu_2': 0.06568991212173905, 'bleu_3': 0.016157708414076512, 'bleu_4': 0.006288011794528766}


 68%|██████▊   | 34/50 [2:10:50<1:01:52, 232.04s/it]

Train loss : 2.6116 | Test loss : 10.5406 | Train time : 189.63 s | Lr : 0.00076839
{'loss': 10.54057532926149, 'accuracy': 0.6152, 'bleu_1': 0.11385217391304621, 'bleu_2': 0.06559171408182525, 'bleu_3': 0.016214076000413892, 'bleu_4': 0.006778178276510567}


 70%|███████   | 35/50 [2:14:34<57:24, 229.64s/it]  

Train loss : 2.6005 | Test loss : 10.5314 | Train time : 182.54 s | Lr : 0.00072997
{'loss': 10.53138794476473, 'accuracy': 0.615, 'bleu_1': 0.1137260869565244, 'bleu_2': 0.06545853629641343, 'bleu_3': 0.01602373229580275, 'bleu_4': 0.006782981405102419}


 72%|███████▏  | 36/50 [2:18:23<53:32, 229.49s/it]

Train loss : 2.5887 | Test loss : 10.4854 | Train time : 187.86 s | Lr : 0.00069347
{'loss': 10.485440423217002, 'accuracy': 0.615, 'bleu_1': 0.11570434782608983, 'bleu_2': 0.06657142193640675, 'bleu_3': 0.01647023786732707, 'bleu_4': 0.006516561189920044}


 74%|███████▍  | 37/50 [2:22:16<49:56, 230.49s/it]

Train loss : 2.5788 | Test loss : 10.5922 | Train time : 191.57 s | Lr : 0.00065880
{'loss': 10.592220113247256, 'accuracy': 0.615, 'bleu_1': 0.11552173913043766, 'bleu_2': 0.0663584114247409, 'bleu_3': 0.016432869986303875, 'bleu_4': 0.006801994408968174}


 76%|███████▌  | 38/50 [2:26:02<45:50, 229.21s/it]

Train loss : 2.5655 | Test loss : 10.5996 | Train time : 183.98 s | Lr : 0.00062586
{'loss': 10.599646012994308, 'accuracy': 0.615, 'bleu_1': 0.11607391304348107, 'bleu_2': 0.06700125109055567, 'bleu_3': 0.016500068517499744, 'bleu_4': 0.006730372974141183}


 76%|███████▌  | 38/50 [2:28:14<46:48, 234.07s/it]


KeyboardInterrupt: 

In [9]:
import util
util.save_model_chunks(image_to_text_model, "checkpoint/un_complete", "model", 95*1024*1024)