In [1]:
import os
import sys

from time import strftime, localtime
import logging
import random
import math

from PIL import Image
import numpy as np
import pickle
import torch
import torch.nn as nn
from torch.nn import init
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import transformers
from transformers import BertTokenizer, BertModel
from transformers import ViTFeatureExtractor, ViTModel

from sklearn import metrics
import spacy
from nltk.corpus import wordnet as wn

from torchvision.transforms import (CenterCrop, 
                                    Compose, 
                                    Normalize, 
                                    RandomHorizontalFlip,
                                    RandomResizedCrop, 
                                    Resize, 
                                    ToTensor)

seed = 777

logger = logging.getLogger()
logger.setLevel(logging.INFO)
transformers.logging.set_verbosity_error()

pretrained_bert_name = '/hy-tmp/models/bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(pretrained_bert_name)
max_seq_len = 100

pretrained_vit_name = '/hy-tmp/models/vit-base-patch16-224'
feature_extractor = ViTFeatureExtractor.from_pretrained(pretrained_vit_name)
normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
crop_size = (feature_extractor.size['height'], feature_extractor.size['width'])
# crop_size = feature_extractor.size
train_transforms = Compose(
        [
            RandomResizedCrop(crop_size),
            RandomHorizontalFlip(),
            ToTensor(),
            normalize,
        ]
    )
val_transforms = Compose(
        [
            Resize(crop_size),
            CenterCrop(crop_size),
            ToTensor(),
            normalize,
        ]
    )

img_dir = '/hy-tmp/data/dataset_image'
train_file = '/hy-tmp/data/processed_train.data'
valid_file = '/hy-tmp/data/processed_valid.data'
test_file = '/hy-tmp/data/processed_test.data'

model_name = 'CM_ATTENTION4'
check_point_path = '/hy-tmp/models'
log_file = f'/root/logs/{model_name}-{strftime("%y%m%d-%H%M", localtime())}.log'
result_file = f'/root/results/{model_name}_predicts.txt'
model_checkpoint = f'{check_point_path}/best_state/{model_name}'

logger.addHandler(logging.StreamHandler(sys.stdout))
logger.addHandler(logging.FileHandler(log_file))

inputs_cols = ['labels', 'box_vit', 'images', 'text_indices', 'text_in_img_indices', 'text_merge_indices', 'attribute_object_indices']
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

bert_dim = 768
vit_dim = 768
polarities_dim = 2
hidden_dim = 512
batch_size = 16
dropout = 0.1
patch_size = 16
val_step = 40

sp_nlp = spacy.load('en_core_web_sm')
filenames = os.listdir(img_dir)

In [2]:
def get_doc(text, max_len=0):
    token_list = []
    text = text.lower().strip()
    
    document = sp_nlp(text)
    spacy_token = [str(x) for x in document]
    spacy_len = len(spacy_token)
    
    # if max_len > 0:
    #     if spacy_len > max_len:
    #         spacy_token = spacy_token[:max_len]

    s = ''
    for token in spacy_token:
        s = s + ' ' + token
    # document = sp_nlp(s)
    # spacy_token = [str(x) for x in document]
    return document, s.strip(), spacy_token

def pad_and_truncate(sequence, maxlen, dtype='int64', padding='post', truncating='post', value=0):
    x = (np.ones(maxlen) * value).astype(dtype)
    if truncating == 'pre':
        trunc = sequence[-maxlen:]
    else:
        trunc = sequence[:maxlen]
    trunc = np.asarray(trunc, dtype=dtype)
    if padding == 'post':
        x[:len(trunc)] = trunc
    else:
        x[-len(trunc):] = trunc
    return x

class attention_Dataset(Dataset):
    def __init__(self, data_file, img_dir, transform=None):
        self.transform = transform
        self.img_dir = img_dir
        
        self.tokenizer = BertTokenizer.from_pretrained(pretrained_bert_name)
        data = pickle.load(open(data_file,'rb'))

        print("{}.data".format(data_file))
        all_data = []
        for key,value in data.items():
            img_id = value['id']
            label = int(value['label'])
            image_file = img_id+'.jpg'
            
            text = value['text']
            attribute_objects = value['attribute_objects']
            text_in_img = value['text_in_img']
            box_vit = value["box_vit"]
            box_vit = [x.numpy() for x in box_vit]
            
            data_ = {
                'img_id': img_id,
                'label':label,
                'box_vit':box_vit,
                'image_file': image_file,
                'text':text,
                'text_in_img':text_in_img,
                'attribute_objects':attribute_objects,
            }
            all_data.append(data_)
        self.all_data = all_data
     
    def text_to_indices(self, text, text_pair=None):
        if text_pair is None:
            encoded_dict = self.tokenizer(
                                text,                      # Sentence to encode.
                                add_special_tokens = True, # Add '[CLS]' and '[SEP]'
                                padding = 'max_length',
                                truncation = True,
                                max_length = max_seq_len,    # Pad & truncate all sentences.
                                return_attention_mask = True,   # Construct attn. masks.
                                return_tensors = 'np',     # Return pytorch tensors.
                                return_length = True,
                                is_split_into_words = True,
                           )

        else:
            encoded_dict = self.tokenizer(
                            text,                      # Sentence to encode.
                            text_pair,
                            add_special_tokens = True, # Add '[CLS]' and '[SEP]'
                            padding = 'max_length',
                            truncation = 'longest_first',
                            max_length = max_seq_len,    # Pad & truncate all sentences.
                            return_attention_mask = True,   # Construct attn. masks.
                            return_tensors = 'np',     # Return pytorch tensors.
                            return_length = True,
                            is_split_into_words = True,
                       )
        return encoded_dict
    
    def my_collate_fn(self, data):
        # use bert tokenizer, no graph returned
        b_img_id = []
        b_label = []
        b_box_vit = []
        b_image = []
        b_text_indices = []
        b_text_in_img_indices = []
        b_text_merge_indices = []
        b_attribute_object_indices = []

        for item in data:
            b_img_id.append(item['img_id'])
            b_label.append(item['label'])
            b_box_vit.append(item['box_vit'])
            b_image.append(item['image'])

            text = item['text']
            text_in_img = item['text_in_img']
            attribute_objects = item['attribute_objects']
            
            attribute_objects_token = []
            for attribute, _object in attribute_objects:
                attribute_objects_token += [attribute, _object]

            text_doc,_,text_token = get_doc(text)
            text_in_img_doc,_,text_in_img_token = get_doc(text_in_img)
            if not text_token:
                text_token = ['']
            if not text_in_img_token:
                text_in_img_token = ['']
            
            b_text_indices.append(text_token)
            b_text_in_img_indices.append(text_in_img_token)
            b_attribute_object_indices.append(attribute_objects_token)
        
        text_encoded_dict = self.text_to_indices(b_text_indices)
        text_in_img_encoded_dict = self.text_to_indices(b_text_in_img_indices)
        text_merge_encoded_dict = self.text_to_indices(b_text_indices, b_text_in_img_indices)
        attribute_object_encoded_dict = self.text_to_indices(b_attribute_object_indices)
        
        return {
                    'img_ids': b_img_id,
                    'labels': torch.tensor(b_label),
                    'box_vit':torch.tensor(np.array(b_box_vit)),
                    'images':torch.stack(b_image, dim=0),
                    'text_indices':torch.tensor(text_encoded_dict.input_ids),
                    'text_in_img_indices':torch.tensor(text_in_img_encoded_dict.input_ids),
                    'text_merge_indices':torch.tensor(text_merge_encoded_dict.input_ids),
                    'attribute_object_indices':torch.tensor(attribute_object_encoded_dict.input_ids),
                    }
    
    def __getitem__(self, index):
        img_id = self.all_data[index]['img_id']
        image_file = self.all_data[index]['image_file']
        label = self.all_data[index]['label']
        box_vit = self.all_data[index]['box_vit']
        image = Image.open(os.path.join(self.img_dir, image_file))
        if self.transform:
            image = self.transform(image)
        text = self.all_data[index]['text']
        text_in_img = self.all_data[index]['text_in_img']
        attribute_objects = self.all_data[index]['attribute_objects']
        
        return {
            'img_id': img_id,
            'label':label,
            'box_vit':box_vit,
            'image': image,
            'text':text,
            'text_in_img':text_in_img,
            'attribute_objects':attribute_objects,
        }

    def __len__(self):
        return len(self.all_data)

In [3]:
train_dataset = attention_Dataset(data_file=train_file, img_dir=img_dir, transform=train_transforms)
valid_dataset = attention_Dataset(data_file=valid_file, img_dir=img_dir, transform=val_transforms)
test_dataset = attention_Dataset(data_file=test_file, img_dir=img_dir, transform=val_transforms)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=train_dataset.my_collate_fn)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, collate_fn=valid_dataset.my_collate_fn)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=test_dataset.my_collate_fn)

print(train_dataset.__len__(), valid_dataset.__len__(), test_dataset.__len__())

/hy-tmp/data/processed_train.data.data
/hy-tmp/data/processed_valid.data.data
/hy-tmp/data/processed_test.data.data
19816 2410 2409


In [4]:
class SimplifiedScaledDotProductAttention(nn.Module):
    '''
    Scaled dot-product attention
    '''

    def __init__(self, d_model, h, dropout=.1):
        '''
        :param d_model: Output dimensionality of the model
        :param d_k: Dimensionality of queries and keys
        :param d_v: Dimensionality of values
        :param h: Number of heads
        '''
        super(SimplifiedScaledDotProductAttention, self).__init__()

        self.d_model = d_model
        self.d_k = d_model//h
        self.d_v = d_model//h
        self.h = h

        self.fc_o = nn.Linear(h * self.d_v, d_model)
        self.dropout=nn.Dropout(dropout)



        self.init_weights()


    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, queries, keys, values, attention_mask=None, attention_weights=None):
        '''
        Computes
        :param queries: Queries (b_s, nq, d_model)
        :param keys: Keys (b_s, nk, d_model)
        :param values: Values (b_s, nk, d_model)
        :param attention_mask: Mask over attention values (b_s, h, nq, nk). True indicates masking.
        :param attention_weights: Multiplicative weights for attention values (b_s, h, nq, nk).
        :return:
        '''
        b_s, nq = queries.shape[:2]
        nk = keys.shape[1]

        q = queries.view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3)  # (b_s, h, nq, d_k)
        k = keys.view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1)  # (b_s, h, d_k, nk)
        v = values.view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3)  # (b_s, h, nk, d_v)

        att = torch.matmul(q, k) / np.sqrt(self.d_k)  # (b_s, h, nq, nk)
        if attention_weights is not None:
            att = att * attention_weights
        if attention_mask is not None:
            att = att.masked_fill(attention_mask, -np.inf)
        
        # we activate the negative att
        att = torch.abs(att)
        
        att = torch.softmax(att, -1)
        
        att=self.dropout(att)

        out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v)  # (b_s, nq, h*d_v)
        out = self.fc_o(out)  # (b_s, nq, d_model)
        return out

class DynamicLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers=1, bias=True, batch_first=True, dropout=0,
                 bidirectional=False, only_use_last_hidden_state=False, rnn_type = 'LSTM'):
        """
        LSTM which can hold variable length sequence, use like TensorFlow's RNN(input, length...).

        :param input_size:The number of expected features in the input x
        :param hidden_size:The number of features in the hidden state h
        :param num_layers:Number of recurrent layers.
        :param bias:If False, then the layer does not use bias weights b_ih and b_hh. Default: True
        :param batch_first:If True, then the input and output tensors are provided as (batch, seq, feature)
        :param dropout:If non-zero, introduces a dropout layer on the outputs of each RNN layer except the last layer
        :param bidirectional:If True, becomes a bidirectional RNN. Default: False
        :param rnn_type: {LSTM, GRU, RNN}
        """
        super(DynamicLSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.batch_first = batch_first
        self.dropout = dropout
        self.bidirectional = bidirectional
        self.only_use_last_hidden_state = only_use_last_hidden_state
        self.rnn_type = rnn_type
        
        if self.rnn_type == 'LSTM': 
            self.RNN = nn.LSTM(
                input_size=input_size, hidden_size=hidden_size, num_layers=num_layers,
                bias=bias, batch_first=batch_first, dropout=dropout, bidirectional=bidirectional)  
        elif self.rnn_type == 'GRU':
            self.RNN = nn.GRU(
                input_size=input_size, hidden_size=hidden_size, num_layers=num_layers,
                bias=bias, batch_first=batch_first, dropout=dropout, bidirectional=bidirectional)
        elif self.rnn_type == 'RNN':
            self.RNN = nn.RNN(
                input_size=input_size, hidden_size=hidden_size, num_layers=num_layers,
                bias=bias, batch_first=batch_first, dropout=dropout, bidirectional=bidirectional)
        

    def forward(self, x, x_len, h0=None):
        """
        sequence -> sort -> pad and pack ->process using RNN -> unpack ->unsort

        :param x: sequence embedding vectors
        :param x_len: numpy/tensor list
        :return:
        """
        """sort"""
        x_sort_idx = torch.argsort(-x_len)
        x_unsort_idx = torch.argsort(x_sort_idx).long()
        x_len = x_len[x_sort_idx]
        x = x[x_sort_idx.long()]
        """pack"""
        x_emb_p = torch.nn.utils.rnn.pack_padded_sequence(x, x_len, batch_first=self.batch_first)
        
        if self.rnn_type == 'LSTM':
            if h0 is None: 
                out_pack, (ht, ct) = self.RNN(x_emb_p, None)
            else:
                out_pack, (ht, ct) = self.RNN(x_emb_p, (h0, h0))
        else: 
            if h0 is None:
                out_pack, ht = self.RNN(x_emb_p, None)
            else:
                out_pack, ht = self.RNN(x_emb_p, h0)
            ct = None
        """unsort: h"""
        ht = torch.transpose(ht, 0, 1)[
            x_unsort_idx]  
        ht = torch.transpose(ht, 0, 1)

        if self.only_use_last_hidden_state:
            return ht
        else:
            """unpack: out"""
            out = torch.nn.utils.rnn.pad_packed_sequence(out_pack, batch_first=self.batch_first)
            out = out[0]  #
            out = out[x_unsort_idx]
            """unsort: out c"""
            if self.rnn_type =='LSTM':
                ct = torch.transpose(ct, 0, 1)[
                    x_unsort_idx]
                ct = torch.transpose(ct, 0, 1)

            return out, (ht, ct)
        
class CM_ATTENTION4(nn.Module):
    def __init__(self, pretrained_bert_name, pretrained_vit_name):
        super(CM_ATTENTION4, self).__init__()
        self.bert = BertModel.from_pretrained(pretrained_bert_name)
        self.vit = ViTModel.from_pretrained(pretrained_vit_name)
        self.attention1 = SimplifiedScaledDotProductAttention(d_model=bert_dim, h=2)
        self.attention2 = SimplifiedScaledDotProductAttention(d_model=bert_dim, h=2)
        
        self.lstm1 = DynamicLSTM(bert_dim, hidden_dim, num_layers=1, batch_first=True, bidirectional=True)
        self.lstm2 = DynamicLSTM(vit_dim, hidden_dim, num_layers=1, batch_first=True, bidirectional=True)
        
        self.fc1 = nn.Linear(4*hidden_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, polarities_dim)
        
        self.dropout=nn.Dropout(dropout)
        
        self.params = []
        for child in self.children():
            if child != self.bert and child != self.vit:
                self.params += child.parameters()
    
    def forward(self, inputs):
        labels, box_vit, images, text_indices, text_in_img_indices, text_merge_indices, attribute_object_indices = inputs
        bs = labels.shape[0]
        box_vit_len = torch.tensor([10]*bs)
        img_patch_len = torch.tensor([patch_size]*bs)
        text_len = torch.sum(text_indices != 0, dim=-1)
        text_in_img_len = torch.sum(text_in_img_indices != 0, dim=-1)
        text_merge_len = torch.sum(text_merge_indices != 0, dim=-1)
        attribute_object_len = torch.sum(attribute_object_indices != 0, dim=-1)
        
        text_out = self.bert(text_merge_indices,  output_hidden_states=False)
        image_out = self.vit(images, output_hidden_states=False)
        
        atte_text = self.attention1(queries=image_out.last_hidden_state, keys=text_out.last_hidden_state,
                                     values=text_out.last_hidden_state) + image_out.last_hidden_state
        atte_image = self.attention2(queries=text_out.last_hidden_state, keys=image_out.last_hidden_state,
                                     values=image_out.last_hidden_state) + text_out.last_hidden_state
        
        atte_text_out, (_, _) = self.lstm1(atte_text, img_patch_len.cpu())
        atte_image_out, (_, _) = self.lstm2(atte_image, text_merge_len.cpu())
        
        feature = torch.cat([atte_text_out[:,0,:], atte_image_out[:,0,:]], dim = 1)
        feature = self.dropout(feature)
        
        output = self.fc2(F.relu(self.fc1(feature)))
        
        return output

    def reset_params(self):
        layers = [self.lstm1, self.lstm2, self.fc1, self.fc2]
        for layer in layers:
            modules = layer.modules()
            for m in modules:
                if isinstance(m, nn.Conv2d):
                    init.kaiming_normal_(m.weight, mode='fan_out')
                    if m.bias is not None:
                        init.constant_(m.bias, 0)
                elif isinstance(m, nn.BatchNorm2d):
                    init.constant_(m.weight, 1)
                    init.constant_(m.bias, 0)
                elif isinstance(m, nn.Linear):
                    init.normal_(m.weight, std=0.001)
                    if m.bias is not None:
                        init.constant_(m.bias, 0)

In [5]:
def eval_(model, data_loader, save_path=None):
    n_correct, n_total = 0, 0
    t_targets_all, t_outputs_all = None, None
    model.eval()
    
    with torch.no_grad():
        for i_batch, t_batch in enumerate(data_loader):
            t_inputs = [t_batch[col].to(device)   for col in inputs_cols]
            t_targets = t_batch['labels'].to(device)
            t_img_ids = t_batch['img_ids']
            
            t_outputs = model(t_inputs)

            n_correct += (torch.argmax(t_outputs, -1) == t_targets).sum().item()
            n_total += len(t_outputs)

            if t_targets_all is None:
                t_targets_all = t_targets
                t_outputs_all = t_outputs
                t_img_ids_all = t_img_ids
            else:
                t_targets_all = torch.cat((t_targets_all, t_targets), dim=0)
                t_outputs_all = torch.cat((t_outputs_all, t_outputs), dim=0)
                t_img_ids_all += t_img_ids
    
    if save_path:
        with open(save_path,'w',encoding='utf-8') as fout:
            img_ids_all = t_img_ids_all
            predicts_all = torch.argmax(t_outputs_all, -1).cpu().numpy().tolist()
            labels_all = t_targets_all.cpu().numpy().tolist()
            outputs_all = t_outputs_all.cpu().numpy().tolist()
            assert len(img_ids_all) == len(predicts_all) == len(labels_all) == len(outputs_all)
            
            for i in range(len(img_ids_all)):
                img_id = img_ids_all[i]
                predict = predicts_all[i]
                label = labels_all[i]
                output = outputs_all[i]
                fout.write(f'{str(img_id)} {str(predict)} {str(label)} {str(output)} \n')

    acc = n_correct / n_total
    f1 = metrics.f1_score(t_targets_all.cpu(), torch.argmax(t_outputs_all, -1).cpu())
    precision =  metrics.precision_score(t_targets_all.cpu(),torch.argmax(t_outputs_all, -1).cpu())
    recall = metrics.recall_score(t_targets_all.cpu(),torch.argmax(t_outputs_all, -1).cpu())
    return acc, f1 ,precision, recall

def train(model, train_data_loader, val_data_loader, test_data_loader):
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam([{'params':model.bert.parameters(),'lr':2e-5},
                                  {'params':model.vit.parameters(),'lr':2e-5},
                                  {'params':model.params,'lr':1e-3},
                                 ], lr=1e-3, weight_decay=1e-5)
    global_step = 0
    max_val_acc = 0
    max_val_f1 = 0
    max_val_epoch = 0
    
    model.reset_params()
    
    for i_epoch in range(100):
        logger.info('>' * 100)
        logger.info('epoch: {}'.format(i_epoch))
        n_correct, n_total, loss_total = 0, 0, 0

        for i_batch, batch in enumerate(train_data_loader):
            model.train()
            global_step += 1

            inputs = [batch[col].to(device)   for col in inputs_cols]
            outputs = model(inputs)
            targets = batch['labels'].to(device)

            loss = criterion(outputs, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            n_correct += (torch.argmax(outputs, -1) == targets).sum().item()
            n_total += len(outputs)
            loss_total += loss.item() * len(outputs)

            train_acc = n_correct / n_total
            train_loss = loss_total / n_total
            logger.info('loss: {:.4f}, acc: {:.4f}'.format(train_loss, train_acc))

            if global_step % val_step == 0:
                val_acc, val_f1,val_precision,val_recall = eval_(model, val_data_loader)
                logger.info('> max_val_f1: {:.4f}, max_val_acc: {:.4f}'.format(max_val_f1,max_val_acc))
                logger.info('> val_acc: {:.4f}, val_f1: {:.4f}, val_precision: {:.4f}, val_recall: {:.4f}'.format(val_acc,val_f1,val_precision,val_recall))

                if val_acc > max_val_acc:
                    max_val_f1 = val_f1
                    max_val_acc = val_acc
                    max_val_epoch = i_epoch
                    
                    torch.save(model.state_dict(), model_checkpoint)
                    logger.info(f'>> saved: {model_checkpoint}')

        if i_epoch - max_val_epoch >= 3:
            logger.info('>> early stop.')
            break

    model.load_state_dict(torch.load(model_checkpoint))
    model = model.to(device)

    test_acc, test_f1,test_precision,test_recall = eval_(model, test_data_loader, save_path=result_file)
    
    logger.info(f"{test_acc} {test_f1} {test_precision} {test_recall}")

    return (test_acc, test_f1,test_precision,test_recall)

In [6]:
def main():
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ['PYTHONHASHSEED'] = str(seed)
    
    model = CM_ATTENTION4(pretrained_bert_name, pretrained_vit_name).to(device)
    
    train(model, train_loader, valid_loader, test_loader)
    
    model.load_state_dict(torch.load(model_checkpoint))
    model = model.to(device)
    print(eval_(model, valid_loader, save_path=f'/root/results/{model_name}_val_predicts.txt'))
    print(eval_(model, test_loader, save_path=result_file))
    
main()

>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
epoch: 0
loss: 0.6931, acc: 0.6250
loss: 0.6942, acc: 0.5312
loss: 0.6936, acc: 0.6042
loss: 0.6912, acc: 0.6406
loss: 0.6926, acc: 0.6000
loss: 0.6930, acc: 0.5729
loss: 0.6891, acc: 0.6161
loss: 0.6856, acc: 0.6172
loss: 0.6766, acc: 0.6181
loss: 0.6743, acc: 0.6125
loss: 0.6605, acc: 0.6193
loss: 0.6513, acc: 0.6250
loss: 0.6363, acc: 0.6442
loss: 0.6246, acc: 0.6518
loss: 0.6104, acc: 0.6625
loss: 0.6183, acc: 0.6641
loss: 0.6306, acc: 0.6618
loss: 0.6383, acc: 0.6632
loss: 0.6269, acc: 0.6711
loss: 0.6231, acc: 0.6719
loss: 0.6184, acc: 0.6786
loss: 0.6065, acc: 0.6875
loss: 0.6050, acc: 0.6929
loss: 0.6192, acc: 0.6875
loss: 0.6165, acc: 0.6850
loss: 0.6166, acc: 0.6827
loss: 0.6170, acc: 0.6829
loss: 0.6156, acc: 0.6853
loss: 0.6146, acc: 0.6875
loss: 0.6164, acc: 0.6833
loss: 0.6163, acc: 0.6835
loss: 0.6161, acc: 0.6836
loss: 0.6155, acc: 0.6837
loss: 0.6197, acc: 0.6765
loss: 