In [None]:
from pathlib import Path
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from shapeworld_data import load_raw_data, get_vocab, ShapeWorld


In [None]:
# Run on the GPU when CUDA is usable, otherwise stay on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")
print(device)

## Prepare Data

In [None]:
def check_raw_data(imgs, labels, langs, id=0):
    """Show the three images of one raw example with its description as the title.

    Args:
        imgs: per-example image groups; each image must support a NumPy-style
            ``transpose((2, 1, 0))`` (its axes are reversed before display).
        labels: per-example label groups; an entry equal to 1 gets the
            "Correct" title on its panel.
        langs: per-example token sequences, joined with spaces for the title.
        id: index of the example to display.
    """
    example_imgs, example_labels, example_lang = list(zip(imgs, labels, langs))[id]
    fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(6, 2))
    fig.suptitle(" ".join(example_lang))
    for position, (flag, picture) in enumerate(zip(example_labels, example_imgs)):
        panel = axes[position]
        # Reverse the axis order so the channel axis ends up last for imshow.
        panel.imshow(picture.transpose((2, 1, 0)))
        if flag == 1:
            panel.set_title("Correct")
    plt.show()

In [None]:
# Resolve the data directory relative to the grandparent of the working directory.
# os.path.abspath('') is the current working directory, so this assumes the kernel
# was started from the notebook's own folder.
root = Path(os.path.abspath('')).parent.parent.absolute()
# Join the components separately: the former "data\shapeworld_np" literal relied on
# the invalid escape sequence "\s" and only formed a valid path on Windows.
data_path = os.path.join(root, "data", "shapeworld_np")
print(data_path)
# os.listdir returns entries in arbitrary order; sort them so that data_list[0] and
# data_list[1] (loaded as the train / test files) are the same on every machine.
data_list = sorted(os.listdir(data_path))
print(data_list)

### Generating vocab_dict 

In [None]:
# Build the vocabulary from every file found in the data directory.
vocab = get_vocab([os.path.join(data_path, name) for name in data_list])
print(vocab["w2i"])

# Attribute codes: 6 color slots and 4 shape slots, one-hot per known word.
# The catch-all entries ("other" / "shape") are all-zero vectors.
COLOR = {
    name: [int(slot == position) for slot in range(6)]
    for position, name in enumerate(["white", "green", "gray", "yellow", "red", "blue"])
}
COLOR["other"] = [0] * 6
SHAPE = {"shape": [0] * 4}
SHAPE.update(
    (name, [int(slot == position) for slot in range(4)])
    for position, name in enumerate(["square", "circle", "rectangle", "ellipse"])
)

w2i, i2w = vocab["w2i"], vocab["i2w"]

# Special-token ids.
# NOTE(review): these are hard-coded rather than read from w2i - confirm that they
# match the vocabulary built above.
PAD, SOS, EOS, UNK = 0, 1, 2, 3

### Prepare the Data loader

In [None]:
# Load the first data file, report the shape of each array, preview one example
# and wrap the data in the DataLoader used for training.
d = load_raw_data(os.path.join(data_path, data_list[0]))
for key in ("imgs", "labels", "langs"):
    print(d[key].shape)
check_raw_data(d["imgs"], d["labels"], d["langs"])
train_dataset = ShapeWorld(d, vocab)
train_batch = DataLoader(train_dataset, batch_size=32, shuffle=False)

In [None]:
# Load the second data file, report the shape of each array, preview example 1
# and wrap the data in the DataLoader used for testing.
d = load_raw_data(os.path.join(data_path, data_list[1]))
for key in ("imgs", "labels", "langs"):
    print(d[key].shape)
check_raw_data(d["imgs"], d["labels"], d["langs"], id=1)
test_dataset = ShapeWorld(d, vocab)
test_batch = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [None]:
# Sanity check: print the tensor shapes of the first training batch (if there is one).
first_batch = next(iter(train_batch), None)
if first_batch is not None:
    imgs, label, lang = first_batch
    for batch_tensor in (imgs, lang, label):
        print(batch_tensor.shape)

## Model

In [None]:
def to_onehot(y, n):
    """Turn a vector of class indices into a float one-hot matrix.

    Args:
        y: integer index tensor; ``y.shape[0]`` rows are produced and
            ``y.view(-1, 1)`` supplies the column to set in each row.
        n: number of columns (classes).

    Returns:
        A ``(y.shape[0], n)`` tensor on ``y``'s device with a 1 at each index.
    """
    onehot = torch.zeros(y.shape[0], n).to(y.device)
    # scatter_ works in place and returns the same tensor.
    return onehot.scatter_(1, y.view(-1, 1), 1)

In [None]:
import re
def utter2tensor(utter):
    """Encode an utterance string as a 10-element color/shape indicator tensor.

    The special tokens ``<PAD>``, ``<sos>``, ``<eos>`` and ``<UNK>`` are removed
    first. The remaining words are then interpreted as:

    * one word that is a COLOR key  -> that color + the default shape code
    * one word that is a SHAPE key  -> the default color code + that shape
    * two words                     -> color of the first + shape of the second,
      falling back to the default codes for unknown words
    * anything else                 -> default color + default shape

    Args:
        utter: space-separated utterance, optionally containing special tokens.

    Returns:
        Integer tensor holding the 6 color slots followed by the 4 shape slots.
    """
    utter = re.sub(r"<PAD>|<sos>|<eos>|<UNK>", "", utter)
    # Removing the special tokens leaves empty strings between the spaces; drop
    # them so that "<sos> red square <eos>" counts as two words, not four.
    words = [word for word in utter.split(" ") if word]
    color, shape = COLOR["other"], SHAPE["shape"]
    if len(words) == 1:
        if words[0] in COLOR:
            color = COLOR[words[0]]
        elif words[0] in SHAPE:
            shape = SHAPE[words[0]]
    elif len(words) == 2:
        # .get keeps unknown words from raising KeyError (same rule as lang2tensor).
        color = COLOR.get(words[0], COLOR["other"])
        shape = SHAPE.get(words[1], SHAPE["shape"])
    return torch.tensor(np.array(color + shape))

In [None]:
class Imgs_emb_DeepSet(nn.Module):
    """Pool two embeddings by element-wise summation, then apply one linear layer."""

    def __init__(self, input_size=1024, output_size=1024):
        """Create the ``input_size -> output_size`` projection applied after pooling."""
        super().__init__()
        self.linear1 = nn.Linear(input_size, output_size)

    def forward(self, img_emb1, img_emb2):
        """Return ``linear1(img_emb1 + img_emb2)``."""
        pooled = img_emb1 + img_emb2
        return self.linear1(pooled)

In [None]:
from vision import ConvNet

class Imgs_Feature(nn.Module):
    """Encode a reference-game context into one feature vector per sample.

    Every image is embedded with the CNN encoder. The two non-target embeddings
    are pooled by the DeepSet module, concatenated with the target embedding and
    projected by a final linear layer.
    """

    def __init__(self, input_size=1024, output_size =1024):
        super(Imgs_Feature, self).__init__()
        self.cnn_encoder = ConvNet(4)
        self.deepset = Imgs_emb_DeepSet(input_size, output_size)
        self.linear = nn.Linear(output_size+input_size, output_size)

    @staticmethod
    def _split_by_target(items, labels):
        """Split ``items`` (B, n_obj, ...) into the target and the remaining objects.

        The target position is taken per sample as the argmax of that sample's
        label row. A single label row is broadcast to the whole batch.

        Returns:
            ``(target, others)`` with shapes (B, ...) and (B, n_obj - 1, ...);
            the others keep their original relative order.
        """
        batch_size, n_obj = items.shape[0], items.shape[1]
        target_idx = labels.reshape(-1, n_obj).argmax(dim=1).to(items.device)
        if target_idx.shape[0] != batch_size:
            target_idx = target_idx.expand(batch_size)
        rows = torch.arange(batch_size, device=items.device)
        positions = torch.arange(n_obj, device=items.device).expand(batch_size, n_obj)
        # Each row keeps exactly n_obj - 1 positions, so the flat mask result
        # can be viewed back as (B, n_obj - 1).
        other_idx = positions[positions != target_idx.unsqueeze(1)].view(batch_size, n_obj - 1)
        return items[rows, target_idx], items[rows.unsqueeze(1), other_idx]

    def forward(self,feats,labels):
        """Compute the context feature.

        Args:
            feats: image batch of shape (B, n_obj, ...); three objects per
                sample are expected (one target, two distractors).
            labels: (B, n_obj) scores whose per-row argmax marks the target.
                Previously the argmax was taken over the flattened batch, so a
                single index was applied to every sample.

        Returns:
            The output of the final linear layer, one row per sample.
        """
        batch_size, n_obj = feats.shape[0], feats.shape[1]
        # Fold the object axis into the batch axis so the CNN sees single images.
        feats_flat = feats.reshape(batch_size * n_obj, *feats.shape[2:])
        cnn_emb = self.cnn_encoder(feats_flat).view(batch_size, n_obj, -1)
        target_img, others = self._split_by_target(cnn_emb, labels)
        img_embs = F.relu(self.deepset(others[:, 0], others[:, 1]))
        imgs = torch.hstack((img_embs, target_img))
        feat = self.linear(imgs)
        return feat

In [None]:
# Original Feature model
from vision import ConvNet
from literal_listener_shapeworld import CNN_encoder

class Original_Imgs_emb_DeepSet(nn.Module):
    """Pool two embeddings by element-wise summation, then apply one linear layer."""

    def __init__(self, input_size=1024, output_size=1024):
        """Create the ``input_size -> output_size`` projection applied after pooling."""
        super().__init__()
        self.linear1 = nn.Linear(input_size, output_size)

    def forward(self, img_emb1, img_emb2):
        """Return ``linear1(img_emb1 + img_emb2)``."""
        pooled = img_emb1 + img_emb2
        return self.linear1(pooled)


class Original_CS_CNN_encoder_Feature(nn.Module):
    """Context feature built from pretrained color and shape CNN encoders.

    Each image is embedded by concatenating the color-encoder and shape-encoder
    outputs. The two non-target embeddings are pooled by the DeepSet module,
    concatenated after the target embedding and projected by a linear layer.
    """

    def __init__(self, input_size=10, output_size=10):
        super(Original_CS_CNN_encoder_Feature, self).__init__()
        # Both encoders are restored from checkpoints under model_params/.
        self.cnn_color_encoder = CNN_encoder(6)
        self.cnn_color_encoder.load_state_dict(torch.load("model_params/shapeworld_original-cnn_color_model.pth",map_location=device))
        self.cnn_shape_encoder = CNN_encoder(4)
        self.cnn_shape_encoder.load_state_dict(torch.load("model_params/shapeworld_original-cnn_shape_model.pth",map_location=device))
        self.deepset_size = 6+4
        self.deepset = Original_Imgs_emb_DeepSet(self.deepset_size, self.deepset_size)
        self.linear = nn.Linear(output_size+input_size, output_size)

    def get_feat_emb(self,feat):
        """Return the color and shape encoder outputs for ``feat``, stacked column-wise."""
        col_embs = self.cnn_color_encoder(feat)
        shape_embs = self.cnn_shape_encoder(feat)
        img_embs = torch.hstack((col_embs,shape_embs))
        return img_embs

    @staticmethod
    def _split_by_target(items, labels):
        """Split ``items`` (B, n_obj, ...) into the target and the remaining objects.

        The target position is taken per sample as the argmax of that sample's
        label row. A single label row is broadcast to the whole batch.

        Returns:
            ``(target, others)`` with shapes (B, ...) and (B, n_obj - 1, ...);
            the others keep their original relative order.
        """
        batch_size, n_obj = items.shape[0], items.shape[1]
        target_idx = labels.reshape(-1, n_obj).argmax(dim=1).to(items.device)
        if target_idx.shape[0] != batch_size:
            target_idx = target_idx.expand(batch_size)
        rows = torch.arange(batch_size, device=items.device)
        positions = torch.arange(n_obj, device=items.device).expand(batch_size, n_obj)
        # Each row keeps exactly n_obj - 1 positions, so the flat mask result
        # can be viewed back as (B, n_obj - 1).
        other_idx = positions[positions != target_idx.unsqueeze(1)].view(batch_size, n_obj - 1)
        return items[rows, target_idx], items[rows.unsqueeze(1), other_idx]

    def forward(self,feats,labels):
        """Compute the context feature.

        Args:
            feats: image batch of shape (B, n_obj, ...); three objects per
                sample are expected (one target, two distractors).
            labels: (B, n_obj) scores whose per-row argmax marks the target.
                Previously the argmax was taken over the flattened batch, so a
                single index was applied to every sample.

        Returns:
            The output of the final linear layer, one row per sample.
        """
        target_img, other_imgs = self._split_by_target(feats, labels)
        target_embs = self.get_feat_emb(target_img)
        other_embs1 = self.get_feat_emb(other_imgs[:, 0])
        other_embs2 = self.get_feat_emb(other_imgs[:, 1])
        other_embs = F.relu(self.deepset(other_embs1,other_embs2))
        # Target embedding first, pooled distractors second.
        embs = torch.hstack((target_embs,other_embs))
        feat = self.linear(embs)
        return feat

In [None]:
from literal_listener_shapeworld import CNN_encoder

class CS_CNN_encoder_Feature(nn.Module):
    """Context feature from pretrained color and shape CNN encoders, without projection.

    Each image is embedded by concatenating the color-encoder and shape-encoder
    outputs. ``forward`` returns the target embedding followed by the sum of the
    two non-target embeddings. The ``deepset``, ``trans_target`` and
    ``linear1``-``linear3`` layers are created but not used by ``forward``.
    """

    def __init__(self, input_size=10, output_size=10):
        super(CS_CNN_encoder_Feature, self).__init__()
        # Both encoders are restored from checkpoints under model_params/.
        self.cnn_color_encoder = CNN_encoder(6)
        self.cnn_color_encoder.load_state_dict(torch.load("model_params/shapeworld_original-cnn_color_model.pth",map_location=device))
        self.cnn_shape_encoder = CNN_encoder(4)
        self.cnn_shape_encoder.load_state_dict(torch.load("model_params/shapeworld_original-cnn_shape_model.pth",map_location=device))
        self.deepset_input_size = 6+4
        self.deepset_output_size = (6+4)*1
        self.deepset = Imgs_emb_DeepSet(self.deepset_input_size, self.deepset_output_size)
        self.trans_target = nn.Linear(self.deepset_input_size, self.deepset_input_size)
        self.linear1 = nn.Linear(self.deepset_output_size+self.deepset_input_size, output_size)
        self.linear2 = nn.Linear(output_size, output_size)
        self.linear3 = nn.Linear(output_size, output_size)

    def get_feat_emb(self,feat):
        """Return the color and shape encoder outputs for ``feat``, stacked column-wise."""
        col_embs = self.cnn_color_encoder(feat)
        shape_embs = self.cnn_shape_encoder(feat)
        img_embs = torch.hstack((col_embs,shape_embs))
        return img_embs

    @staticmethod
    def _split_by_target(items, labels):
        """Split ``items`` (B, n_obj, ...) into the target and the remaining objects.

        The target position is taken per sample as the argmax of that sample's
        label row. A single label row is broadcast to the whole batch.

        Returns:
            ``(target, others)`` with shapes (B, ...) and (B, n_obj - 1, ...);
            the others keep their original relative order.
        """
        batch_size, n_obj = items.shape[0], items.shape[1]
        target_idx = labels.reshape(-1, n_obj).argmax(dim=1).to(items.device)
        if target_idx.shape[0] != batch_size:
            target_idx = target_idx.expand(batch_size)
        rows = torch.arange(batch_size, device=items.device)
        positions = torch.arange(n_obj, device=items.device).expand(batch_size, n_obj)
        # Each row keeps exactly n_obj - 1 positions, so the flat mask result
        # can be viewed back as (B, n_obj - 1).
        other_idx = positions[positions != target_idx.unsqueeze(1)].view(batch_size, n_obj - 1)
        return items[rows, target_idx], items[rows.unsqueeze(1), other_idx]

    def forward(self,feats,labels):
        """Compute the context feature.

        Args:
            feats: image batch of shape (B, n_obj, ...); three objects per
                sample are expected (one target, two distractors).
            labels: (B, n_obj) scores whose per-row argmax marks the target.
                Previously only the first row was consulted, so the first
                sample's target position was applied to the whole batch.

        Returns:
            Target embedding and summed distractor embeddings, stacked column-wise.
        """
        target_img, other_imgs = self._split_by_target(feats, labels)
        target_embs = self.get_feat_emb(target_img)
        other_embs1 = self.get_feat_emb(other_imgs[:, 0])
        other_embs2 = self.get_feat_emb(other_imgs[:, 1])
        # Plain sum pooling; the learned DeepSet projection is deliberately bypassed.
        other_embs = other_embs1+other_embs2
        embs = torch.hstack((target_embs,other_embs))
        return embs

In [None]:
from literal_listener_shapeworld import CNN_encoder

class CS_CNN_no_deepset_Feature(nn.Module):
    """Context feature from pretrained color and shape CNN encoders, without pooling.

    Each image is embedded by concatenating the color-encoder and shape-encoder
    outputs. The target embedding and the two non-target embeddings are
    concatenated (target first) and projected by one linear layer.
    """

    def __init__(self, input_size=10, output_size=10):
        super(CS_CNN_no_deepset_Feature, self).__init__()
        # Both encoders are restored from checkpoints under model_params/.
        self.cnn_color_encoder = CNN_encoder(6)
        self.cnn_color_encoder.load_state_dict(torch.load("model_params/shapeworld_original-cnn_color_model.pth",map_location=device))
        self.cnn_shape_encoder = CNN_encoder(4)
        self.cnn_shape_encoder.load_state_dict(torch.load("model_params/shapeworld_original-cnn_shape_model.pth",map_location=device))
        self.deepset_size = 6+4
        self.linear = nn.Linear(input_size*3, output_size)

    def get_feat_emb(self,feat):
        """Return the color and shape encoder outputs for ``feat``, stacked column-wise."""
        col_embs = self.cnn_color_encoder(feat)
        shape_embs = self.cnn_shape_encoder(feat)
        img_embs = torch.hstack((col_embs,shape_embs))
        return img_embs

    @staticmethod
    def _split_by_target(items, labels):
        """Split ``items`` (B, n_obj, ...) into the target and the remaining objects.

        The target position is taken per sample as the argmax of that sample's
        label row. A single label row is broadcast to the whole batch.

        Returns:
            ``(target, others)`` with shapes (B, ...) and (B, n_obj - 1, ...);
            the others keep their original relative order.
        """
        batch_size, n_obj = items.shape[0], items.shape[1]
        target_idx = labels.reshape(-1, n_obj).argmax(dim=1).to(items.device)
        if target_idx.shape[0] != batch_size:
            target_idx = target_idx.expand(batch_size)
        rows = torch.arange(batch_size, device=items.device)
        positions = torch.arange(n_obj, device=items.device).expand(batch_size, n_obj)
        # Each row keeps exactly n_obj - 1 positions, so the flat mask result
        # can be viewed back as (B, n_obj - 1).
        other_idx = positions[positions != target_idx.unsqueeze(1)].view(batch_size, n_obj - 1)
        return items[rows, target_idx], items[rows.unsqueeze(1), other_idx]

    def forward(self,feats,labels):
        """Compute the context feature.

        Args:
            feats: image batch of shape (B, n_obj, ...); three objects per
                sample are expected (one target, two distractors).
            labels: (B, n_obj) scores whose per-row argmax marks the target.
                Previously the argmax was taken over the flattened batch, so a
                single index was applied to every sample.

        Returns:
            The output of the final linear layer, one row per sample.
        """
        target_img, other_imgs = self._split_by_target(feats, labels)
        target_embs = self.get_feat_emb(target_img)
        other_embs1 = self.get_feat_emb(other_imgs[:, 0])
        other_embs2 = self.get_feat_emb(other_imgs[:, 1])
        # Target first, then the distractors in their original order.
        embs = torch.hstack((target_embs,other_embs1,other_embs2))
        feat = self.linear(embs)
        return feat

In [None]:
class Speaker(nn.Module):
    """GRU speaker that samples an utterance for a reference-game context.

    The context is embedded by ``feat_model`` and projected to the initial GRU
    hidden state. Tokens are then sampled one at a time with hard Gumbel-softmax,
    each sampled one-hot being embedded and fed back as the next input.
    """

    def __init__(self, feat_model, embedding_module, feat_size=1024, hidden_size=100):
        """
        Args:
            feat_model: module called as ``feat_model(feats, labels)`` that
                returns a ``(B, feat_size)`` context feature.
            embedding_module: ``nn.Embedding`` providing the token vectors; its
                size also defines the vocabulary size.
            feat_size: width of the context feature.
            hidden_size: GRU hidden width per direction.
        """
        super(Speaker, self).__init__()
        self.embedding = embedding_module
        self.feat_model = feat_model
        self.embedding_dim = embedding_module.embedding_dim
        self.vocab_size = embedding_module.num_embeddings
        self.hidden_size = hidden_size
        self.gru = nn.GRU(self.embedding_dim, self.hidden_size, num_layers=3, bidirectional=True)
        self.outputs2vocab = nn.Linear(self.hidden_size*2, self.vocab_size)  # *2 for bidirectional
        self.init_h = nn.Linear(feat_size, self.hidden_size)

    def forward(self, feats, labels, tau=10, length_penalty=False, max_len=40):
        """Sample one utterance per sample.

        Args:
            feats: context batch; only ``feats.size(0)`` and ``feats.device`` are
                used here, the rest is passed to ``feat_model``.
            labels: target indicator passed through to ``feat_model``.
            tau: Gumbel-softmax temperature.
            length_penalty: accepted for interface compatibility; unused.
            max_len: total utterance length including the SOS and EOS slots.

        Returns:
            ``(lang_tensor, lang_length, 0, [])`` where ``lang_tensor`` is a
            ``(B, max_len, vocab)`` stack of one-hot tokens (SOS first, EOS last)
            and ``lang_length`` is a vector of ones (lengths are not tracked).
            The last two entries are placeholders.
        """
        batch_size = feats.size(0)

        # Initialise every layer/direction of the GRU with the same projected feature.
        feats_emb = self.feat_model(feats,labels)
        states = self.init_h(feats_emb).unsqueeze(0)
        n_directions = 2 if self.gru.bidirectional else 1
        states = states.repeat(self.gru.num_layers * n_directions, 1, 1)

        # Collects the sampled one-hot vectors, one (B, 1, V) entry per step.
        lang = []
        lang_length = torch.ones(batch_size, dtype=torch.int64).to(feats.device)

        # The first token is always SOS.
        sos_onehot = torch.zeros(batch_size, 1, self.vocab_size).to(feats.device)
        sos_onehot[:, 0, SOS] = 1.0
        lang.append(sos_onehot)

        # (B, 1, V) -> (1, B, V), then embed by multiplying with the embedding matrix.
        inputs = sos_onehot.transpose(0, 1) @ self.embedding.weight

        self.gru.flatten_parameters()
        for _ in range(max_len - 2):  # leave room for the SOS and EOS slots
            outputs, states = self.gru(inputs, states)                          # (1, B, 2H)
            # Drop only the length axis: a bare squeeze() would also drop the
            # batch axis when B == 1 and break the next GRU step.
            outputs = outputs.squeeze(0)                                        # (B, 2H)
            logits = self.outputs2vocab(outputs)                                # (B, V)
            predicted_onehot = F.gumbel_softmax(logits, tau=tau, hard=True)     # (B, V)
            lang.append(predicted_onehot.unsqueeze(1))
            inputs = predicted_onehot.unsqueeze(0) @ self.embedding.weight      # (1, B, E)

        # The last slot is always EOS.
        eos_onehot = torch.zeros(batch_size, 1, self.vocab_size).to(feats.device)
        eos_onehot[:, 0, EOS] = 1.0
        lang.append(eos_onehot)

        lang_tensor = torch.cat(lang, 1)                                        # (B, max_len, V)
        return lang_tensor, lang_length, 0, []

## Training the model

### helper functions

In [None]:
def langs2tensor(langs):
    """Encode a batch of token-id sequences as color/shape indicator rows.

    Args:
        langs: tensor whose first axis indexes utterances; each row is passed
            to ``lang2tensor``.

    Returns:
        Float tensor of shape ``(langs.size(0), 10)``, one encoded row per
        utterance.
    """
    n_utterances = langs.size(0)
    encoded = torch.zeros(n_utterances, 10)
    for row in range(n_utterances):
        encoded[row] = lang2tensor(langs[row])
    return encoded

def lang2tensor(lang):
    """Encode one token-id sequence as a 10-element color/shape indicator tensor.

    The ids are mapped to words through ``i2w``; ``<PAD>``, ``<sos>`` and
    ``<eos>`` are removed. One remaining word is read as a color (checked first)
    or a shape; two words are read as color then shape, unknown words falling
    back to the default codes. Any other word count gives the default codes.

    Returns:
        Integer tensor holding the 6 color slots followed by the 4 shape slots.
    """
    sentence = " ".join([i2w[int(token_id)] for token_id in lang])
    sentence = re.sub(r"<PAD>|<sos>|<eos>","",sentence)
    # Stripping the special tokens leaves empty strings between spaces; skip them.
    tokens = [token for token in sentence.split(" ") if token]

    color_code, shape_code = COLOR["other"], SHAPE["shape"]
    if len(tokens) == 1:
        only = tokens[0]
        if only in COLOR.keys():
            color_code = COLOR[only]
        elif only in SHAPE.keys():
            shape_code = SHAPE[only]
    elif len(tokens) == 2:
        color_code = COLOR.get(tokens[0], COLOR["other"])
        shape_code = SHAPE.get(tokens[1], SHAPE["shape"])
    return torch.tensor(np.array(color_code + shape_code))

### training

In [None]:
from literal_listener_shapeworld import ShapeWorld_RNN_L0

# Speaker: token embeddings + context-feature model, moved to the selected device.
emb_dim = 768
feat_dim = 10
speaker_embs = nn.Embedding(len(w2i), emb_dim)
speaker_feat = Original_CS_CNN_encoder_Feature(output_size=feat_dim)
speaker = Speaker(speaker_feat, speaker_embs, feat_size=feat_dim)
speaker.to(device)

# Pretrained literal listener. The path uses forward slashes: the former
# "model_params\shapeworld..." literal relied on the invalid escape "\s" and only
# resolved on Windows, unlike the other checkpoint paths in this notebook.
literal_listener = ShapeWorld_RNN_L0(len(w2i)).to(device)
literal_listener.load_state_dict(torch.load("model_params/shapeworld_rnn_full-data_100epoch_l0_last.pth",map_location=device))

# Total utterance length generated by the speaker, including SOS and EOS.
max_len = 4

# NOTE(review): the listener's parameters are optimised together with the speaker's,
# so the pretrained listener is updated during speaker training - confirm this is intended.
optimizer = optim.Adam(list(speaker.parameters())+list(literal_listener.parameters()),lr=0.01)
criterion = nn.CrossEntropyLoss()

epoch = 5

# Checkpoint files for the best test loss / best test accuracy.
best_loss_pth = "model_params/shapeworld_S1_lis=emb-rnn_CS-CNN-more-linear_no-penalty_col-shape-loss_best-loss.pth"
best_acc_pth = "model_params/shapeworld_S1_lis=emb-rnn_CS-CNN-more-linear_no-penalty_col-shape-loss_best-acc.pth"

In [None]:
# Per-epoch history, plotted after training.
train_loss_list = []
test_loss_list = []
train_acc_list = []
test_acc_list = []
# Start from +inf so the first epoch always produces a "best loss" checkpoint
# (the former start value of 100 would skip saving if the loss never dropped below it).
best_loss = float("inf")
best_acc = 0
for i in range(epoch):
    print("##############################################")
    print("Epoch:{}/{}".format(i+1,epoch))
    train_loss = 0
    test_loss = 0
    train_acc = 0
    test_acc = 0

    # ---- training pass ----
    literal_listener.train()
    speaker.train()
    for cols,label,lang in train_batch:
        cols, lang = cols.to(device).to(torch.float), lang.to(device)
        label = label.to(device).to(torch.float)
        optimizer.zero_grad()
        lang_tensor,lang_length,eos_loss,lang_prob = speaker(cols, label, length_penalty=False, max_len=max_len)
        output_lang = lang_tensor.argmax(2)
        # Color / shape agreement between the reference and the generated utterance.
        # Both encodings are built from token ids, so these terms carry no gradient.
        correct_colshape = langs2tensor(lang)
        output_colshape = langs2tensor(output_lang)
        correct_col, correct_shape = correct_colshape[:,:6], correct_colshape[:,6:]
        output_col, output_shape = output_colshape[:,:6], output_colshape[:,6:]
        col_loss = criterion(output_col,correct_col)
        shape_loss = criterion(output_shape,correct_shape)
        # Listener loss on the generated utterance.
        # NOTE(review): output_lang comes from argmax, which is not differentiable, so it
        # looks like no gradient reaches the speaker through this call - confirm whether
        # the listener can consume the one-hot lang_tensor instead.
        lis_labels = literal_listener(cols, output_lang)
        lis_loss = criterion(lis_labels,label)
        loss = lis_loss*1.0 + col_loss*1.0 + shape_loss*1.0 + eos_loss*0.0001
        # Backpropagate the combined objective that is also the one being reported.
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        # Accuracy against the batch's own target positions rather than a fixed index.
        pred_labels = torch.argmax(lis_labels,dim=1)
        train_acc += (pred_labels == label.argmax(dim=1)).float().mean().item()
    batch_train_loss = train_loss/len(train_batch)
    batch_train_acc = train_acc/len(train_batch)

    # ---- evaluation pass: both models in eval mode, no gradient tracking ----
    literal_listener.eval()
    speaker.eval()
    with torch.no_grad():
        for cols,label,lang in test_batch:
            cols, lang = cols.to(device).to(torch.float), lang.to(device)
            label = label.to(device).to(torch.float)
            lang_tensor,lang_length,eos_loss,lang_prob = speaker(cols, label, length_penalty=False, max_len=max_len)
            output_lang = lang_tensor.argmax(2)
            correct_colshape = langs2tensor(lang)
            output_colshape = langs2tensor(output_lang)
            correct_col, correct_shape = correct_colshape[:,:6], correct_colshape[:,6:]
            output_col, output_shape = output_colshape[:,:6], output_colshape[:,6:]
            col_loss = criterion(output_col,correct_col)
            shape_loss = criterion(output_shape,correct_shape)
            lis_labels = literal_listener(cols, output_lang)
            lis_loss = criterion(lis_labels,label)
            loss = lis_loss*1.0 + col_loss*1.0 + shape_loss*1.0 + eos_loss*0.0001
            test_loss += loss.item()
            pred_labels = torch.argmax(lis_labels,dim=1)
            test_acc += (pred_labels == label.argmax(dim=1)).float().mean().item()
    batch_test_loss = test_loss/len(test_batch)
    batch_test_acc = test_acc/len(test_batch)

    print("Train Loss:{:.2E}, Test Loss:{:.2E}".format(batch_train_loss,batch_test_loss))
    print("Train Acc:{:.2E}, Test Acc:{:.2E}".format(batch_train_acc,batch_test_acc))
    train_loss_list.append(batch_train_loss)
    test_loss_list.append(batch_test_loss)
    train_acc_list.append(batch_train_acc)
    test_acc_list.append(batch_test_acc)
    # Keep the speaker weights of the best epoch by test loss and by test accuracy.
    if batch_test_loss < best_loss:
        print("Best loss saved ... ")
        torch.save(speaker.state_dict(),best_loss_pth)
        best_loss = batch_test_loss
    if batch_test_acc > best_acc:
        print("Best acc saved ... ")
        torch.save(speaker.state_dict(),best_acc_pth)
        best_acc = batch_test_acc


In [None]:
# visualization
def _plot_history(title, ylabel, train_values, test_values, train_label, test_label):
    """Draw train (solid blue) and test (dashed red) curves over the epochs in a new figure."""
    epochs = range(1, epoch + 1)
    plt.figure()
    plt.title(title)
    plt.xlabel("epoch")
    plt.ylabel(ylabel)
    plt.plot(epochs, train_values, "b-", label=train_label)
    plt.plot(epochs, test_values, "r--", label=test_label)
    plt.legend()
    plt.show()


_plot_history("Train and Test Loss", "Literal_Listener_loss",
              train_loss_list, test_loss_list, "train_loss", "test_loss")
# The accuracy history may hold tensors; convert to plain floats before plotting.
train_acc_list = list(map(float, train_acc_list))
test_acc_list = list(map(float, test_acc_list))
_plot_history("Train and Test Accuracy", "Accuracy",
              train_acc_list, test_acc_list, "train_acc", "test_acc")

## Accuracy test

In [None]:
# Load the second data file again, report the shape of each array, preview example 1
# and wrap the data in the DataLoader used for the accuracy checks.
d = load_raw_data(os.path.join(data_path, data_list[1]))
for key in ("imgs", "labels", "langs"):
    print(d[key].shape)
check_raw_data(d["imgs"], d["labels"], d["langs"], id=1)
eval_dataset = ShapeWorld(d, vocab)
eval_batch = DataLoader(eval_dataset, batch_size=32, shuffle=False)

In [None]:
import re
def check_data(imgs, label, g_langs, c_langs):
    """Show the three images of one example, titled with generated and reference text.

    Args:
        imgs: the example's images; each must support ``transpose(2, 0)``
            (first and last axes are swapped before display).
        label: per-image flags; an entry equal to 1 gets the "Correct" title.
        g_langs: generated tokens, joined with spaces.
        c_langs: reference tokens, joined with spaces.
    """
    fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(6, 4))
    raw_title = "Generated: " + " ".join(g_langs) + "\n\nCorrect: " + " ".join(c_langs)
    # Hide the sentence delimiters in the displayed title.
    fig.suptitle(re.sub(r"<sos>|<eos>", "", raw_title))
    for position, (flag, picture) in enumerate(zip(label, imgs)):
        panel = axes[position]
        panel.imshow(picture.transpose(2, 0))
        if flag == 1:
            panel.set_title("Correct")
    plt.show()

### L0 accuracy

In [None]:
# Literal-listener (L0) accuracy: the speaker describes the true target and the
# listener has to pick it. Uses the speaker weights with the best test accuracy.
speaker.load_state_dict(torch.load(best_acc_pth,map_location=device))
speaker.to(device)
literal_listener.eval()
speaker.eval()
losss = []
accs = []
with torch.no_grad():
    for i,(cols,label,lang) in enumerate(eval_batch):
        cols, lang = cols.to(device).to(torch.float), lang.to(device)
        label = label.to(device).to(torch.float)
        lang_tensor,lang_length,eos_loss,lang_prob = speaker(cols, label, length_penalty=False, max_len=max_len)
        output_lang = lang_tensor.argmax(2)
        lis_labels = literal_listener(cols, output_lang)
        loss = criterion(lis_labels,label) + eos_loss*0.0001
        # Record metrics for every batch. They used to be recorded only inside the
        # "every 5th batch" preview branch, so most batches were ignored.
        losss.append(loss.item())
        pred_labels = torch.argmax(lis_labels,dim=1)
        # Compare with the batch's own target positions instead of a fixed index 0.
        acc = (pred_labels == label.argmax(dim=1)).float().mean()
        accs.append(acc.item())
        if i%5 == 0:
            # Preview the first example of the batch.
            imgs = cols[0].to("cpu")
            c_langs = [i2w[idx] for idx in lang[0].to("cpu").tolist()]
            g_langs = [i2w[idx] for idx in output_lang[0].to("cpu").tolist()]
            check_data(imgs, label[0], g_langs, c_langs)
        if i > 100: break

print("Accuracy:,",np.mean(accs))


### L1 accuracy

In [None]:
# Pragmatic-listener (L1) accuracy: the utterance produced for the 3rd image is scored
# under the speaker conditioned on each candidate target, and the best-scoring
# candidate is the listener's guess. The expected guess is therefore index 2.
def _utterance_log_prob(word_dists, utterance):
    """Sum over positions of log(word_dists[b, t, utterance[b, t]] + 0.001), per sample."""
    picked = word_dists.gather(2, utterance.unsqueeze(2)).squeeze(2)    # (B, L)
    return torch.log(picked + 0.001).sum(dim=1)                         # (B,)


literal_listener.eval()
speaker.eval()
losss = []
accs = []
with torch.no_grad():
    for i,(cols,label,lang) in enumerate(eval_batch):
        cols, lang = cols.to(device).to(torch.float), lang.to(device)
        label = label.to(device).to(torch.float)
        # Utterance for the labelled target.
        # NOTE(review): the next two calls are "for 2nd / 3rd image", which suggests the
        # labelled target is assumed to be the 1st image - confirm against the data.
        lang_tensor,lang_length,eos_loss,lang_prob = speaker(cols, label, length_penalty=False, max_len=max_len)
        output_lang1 = lang_tensor.argmax(2)
        # for 2nd image
        label02 = torch.zeros_like(label)
        label02[:,1] = 1.0
        lang_tensor1,lang_length,eos_loss,lang_prob = speaker(cols, label02, length_penalty=False, max_len=max_len)
        output_lang2 = lang_tensor1.argmax(2)
        # for 3rd image
        label03 = torch.zeros_like(label)
        label03[:,2] = 1.0
        lang_tensor2,lang_length,eos_loss,lang_prob = speaker(cols, label03, length_penalty=False, max_len=max_len)
        output_lang3 = lang_tensor2.argmax(2)

        # Score the 3rd image's utterance under each of the three speaker outputs.
        probs = torch.stack(
            [_utterance_log_prob(dists, output_lang3) for dists in (lang_tensor, lang_tensor1, lang_tensor2)],
            dim=1,
        ).detach().to("cpu")                                            # (B, 3)
        loss = criterion(probs.to(device),label) + eos_loss*0.0001
        losss.append(loss.item())
        pred_labels = torch.argmax(probs,dim=1)
        acc = (pred_labels == 2).float().mean()
        accs.append(acc.item())

        if i%10 == 0:
            # Preview the first example with all three generated utterances.
            # output_lang1 is used here: this cell never defined "output_lang", so the
            # old code silently reused a leftover variable from the previous cell.
            imgs = cols[0].to("cpu")
            c_langs = [i2w[idx] for idx in lang[0].to("cpu").tolist()]
            g_langs = [i2w[idx] for idx in output_lang1[0].to("cpu").tolist()] \
                    + ["|"]+ [i2w[idx] for idx in output_lang2[0].to("cpu").tolist()] \
                    + ["|"]+ [i2w[idx] for idx in output_lang3[0].to("cpu").tolist()]
            check_data(imgs, label[0], g_langs, c_langs)
            print(torch.exp(probs[0]))
            print(torch.where(torch.exp(probs[0])>0.1,1,0))
        if i > 100: break

print("Accuracy:,",np.mean(accs))