In [None]:
from pathlib import Path
import os
from nltk.tokenize import word_tokenize
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader
from corpus import ColorsCorpusReader


## Prepare Data

In [None]:
def check_row_data(cols, contexts, id=0):
    """Draw the three colour swatches of one example under its utterance.

    Args:
        cols: indexable collection; ``cols[id]`` must unpack into exactly
            three matplotlib-compatible colours.
        contexts: indexable collection; ``contexts[id]`` becomes the figure
            title.
        id: position of the example to draw.

    The first two swatches are given their own colour as edge colour, the
    third one is given a black edge colour.
    """
    first, second, third = cols[id]
    title = contexts[id]
    fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(4, 3))
    fig.suptitle(title)
    # (face colour, edge colour) for each panel, left to right.
    swatches = ((first, first), (second, second), (third, "black"))
    for ax, (face, edge) in zip(axes, swatches):
        ax.add_patch(mpatch.Rectangle((0, 0), 1, 1, color=face, ec=edge, lw=8))
        ax.axis('off')
    plt.show()

In [None]:
def sentence2index(sentence):
    """Convert a raw sentence into a list of vocabulary ids.

    The sentence is split with ``nltk.word_tokenize`` and every token is
    looked up in the module-level ``vocab_dict``.

    Args:
        sentence: the utterance as a plain string.

    Returns:
        A list of integer ids, one per token. Tokens that are missing from
        ``vocab_dict`` map to the ``"<unk>"`` id instead of raising
        ``KeyError`` (this matters when ``vocab.pkl`` was cached from a
        different version of the corpus).
    """
    unk_id = vocab_dict["<unk>"]
    return [vocab_dict.get(token, unk_id) for token in word_tokenize(sentence)]

In [None]:
# prepare raw data
# The data directory is expected two levels above the notebook's working directory.
root = Path(os.path.abspath('')).parent.parent.absolute()
data_path = os.path.join(root,"data")
print(data_path)
corpus = ColorsCorpusReader(os.path.join(data_path,"colors.csv"), word_count=None, normalize_colors=True)
examples = list(corpus.read())
print("Number of datapoints: {}".format(len(examples)))
# Query every example exactly once and split the result afterwards, so the
# colours and the utterance of a row always come from the same call
# (the original called get_context_data() twice per example).
context_data = [e.get_context_data() for e in examples]
colors_data = [ctx[0] for ctx in context_data]       # first element: the colours of the context
utterance_data = [ctx[1] for ctx in context_data]    # second element: the utterance

In [None]:
# Visual sanity check: draw the colours of the example at index 4 with its utterance as title.
check_row_data(colors_data,utterance_data,id=4)

### Generating vocab

In [None]:
from functools import reduce
import pickle
# generate vocab dict
# The mapping word -> id is cached in vocab.pkl so that ids stay stable between sessions.
if not os.path.exists("vocab.pkl"):
    print("Generating vocab dict ...")
    special_tokens = ["<pad>","<sos>","<eos>","<unk>"]                                        # Added padding for batching
    # A set comprehension avoids the quadratic list concatenation of reduce(+).
    # Sorting makes the word -> id assignment reproducible (set order is not),
    # and removing the special tokens guarantees they keep ids 0..3.
    vocab_words = sorted({w for c in utterance_data for w in word_tokenize(c)} - set(special_tokens))
    vocab_list = special_tokens + vocab_words
    vocab_dict = {word: idx for idx, word in enumerate(vocab_list)}
    with open('vocab.pkl', 'wb') as f:
        pickle.dump(vocab_dict, f)
else:
    print("Loading vocab dict ...")
    # NOTE: pickle.load executes arbitrary code from the file; only load a
    # vocab.pkl that was produced by this notebook.
    with open('vocab.pkl', 'rb') as f:
        vocab_dict = pickle.load(f)
print("Length of the Vocab list is ",len(vocab_dict))
# Print the ids of the special tokens and of a few common colour words as a sanity check.
for token in ("<pad>", "<sos>", "<eos>", "<unk>", "blue", "red", "green"):
    print("{} id = ".format(token),vocab_dict[token])

In [None]:
# Ids of the special tokens <pad>, <sos>, <eos>, <unk> (in that order: 0, 1, 2, 3).
PAD, SOS, EOS, UNK = range(4)

In [None]:
# Word -> id is the vocabulary itself; id -> word is its inverse mapping.
w2i = vocab_dict
i2w = {index: word for word, index in vocab_dict.items()}

### Prepare the Data loader

In [None]:
# Batching
colors_data_tensor = torch.tensor(np.array(colors_data),dtype=torch.float)
context_id_data = list(map(sentence2index,utterance_data))
max_context_len = max(len(c) for c in context_id_data)
# <sos> + context + <eos> + <pad>*  -> every row holds max_context_len + 2 ids.
# The dtype is fixed here (int64) so the rows can be used directly below without
# re-wrapping each one in torch.tensor(), which triggers a copy-construct warning.
padded_context_data = torch.tensor(
    np.array([[SOS] + c + [EOS] + [PAD] * (max_context_len - len(c)) for c in context_id_data]),
    dtype=torch.long,
)
print("Colors shape = ",colors_data_tensor.shape)
print("Padded context id lists shape = ",padded_context_data.shape)

# One (colours, padded utterance ids) pair per example.
data = list(zip(colors_data_tensor, padded_context_data))
# One-hot target over the three colours: index 2 is marked as the target for every example.
label = torch.zeros(len(data),3)
label[:,2] = 1.0
print("total data length = ",len(data))
print("total label shape = ",label.shape)

# The last `test_split` examples are held out for testing (no shuffling before the split).
test_split = 7000
train_data, test_data = data[:-test_split], data[-test_split:]
train_label, test_label = label[:-test_split], label[-test_split:]
print("Train, Test data length = ",len(train_data),",",len(test_data))
print("Train, Test label length = ",len(train_label),",",len(test_label))

# Each dataset item is ((colours, utterance ids), label).
train_dataset = list(zip(train_data,train_label))
test_dataset = list(zip(test_data,test_label))
train_batch = DataLoader(dataset=train_dataset,batch_size=128,shuffle=True,num_workers=0)
test_batch = DataLoader(dataset=test_dataset,batch_size=128,shuffle=False,num_workers=0)

In [None]:
# Peek at the first training batch and print the shape of its three tensors
# (colours, utterance ids, labels). Note that `label` is rebound to a batch here.
for (cols, lang), label in train_batch:
    for batch_tensor in (cols, lang, label):
        print(batch_tensor.shape)
    break

## Model

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print(device)

In [None]:
def to_onehot(y, n):
    """Return a float one-hot encoding of the integer tensor ``y``.

    Args:
        y: integer index tensor; it is viewed as a column of ``y.shape[0]`` ids.
        n: number of classes (width of the encoding).

    Returns:
        A ``(y.shape[0], n)`` tensor on ``y``'s device with a 1 at each id.
    """
    onehot = torch.zeros(y.shape[0], n, device=y.device)
    # scatter_ writes in place and returns the same tensor.
    return onehot.scatter_(1, y.view(-1, 1), 1)

In [None]:
# For permutation invariance
class Colors_DeepSet(nn.Module):
    """Order-independent encoder for a pair of colours.

    Both colours go through the same linear layer + ReLU, the two results are
    added, and a second linear layer maps the sum to the output embedding.
    Because addition is commutative, swapping the inputs gives the same output.
    """

    def __init__(self, input_size=3, output_size=16):
        super().__init__()
        self.linear1 = nn.Linear(input_size, output_size)   # shared per-colour encoder
        self.linear2 = nn.Linear(output_size, output_size)  # applied to the pooled sum

    def forward(self, col1, col2):
        """Embed the unordered pair ``(col1, col2)``."""
        encoded = [torch.relu(self.linear1(color)) for color in (col1, col2)]
        return self.linear2(encoded[0] + encoded[1])

In [None]:
class Colors_Feature(nn.Module):
    """Encode a three-colour context relative to its target colour.

    The target colour is kept as-is, the two remaining colours are summarised
    by a permutation-invariant ``Colors_DeepSet``, and a linear layer maps the
    concatenation ``[target, distractor embedding]`` to ``output_size``.
    """

    def __init__(self, input_size=3, output_size =16):
        super(Colors_Feature, self).__init__()
        self.deepset = Colors_DeepSet(input_size, output_size)
        self.linear = nn.Linear(input_size+output_size, output_size)

    def forward(self,feats,labels):
        """Compute the context feature for every row of the batch.

        Args:
            feats: ``(B, 3, input_size)`` tensor with the three colours.
            labels: ``(B, 3)`` one-hot (or score) tensor whose per-row argmax
                marks the target colour. A 1-D ``(3,)`` tensor is also
                accepted and applied to every row.

        Returns:
            ``(B, output_size)`` feature tensor.

        The target is selected per row. The previous implementation used
        ``int(torch.argmax(labels))``, i.e. the argmax of the flattened
        tensor, which applied the first row's target index to the whole
        batch. For batches in which every row has the same target the result
        is unchanged.
        """
        batch_size = feats.size(0)
        if labels.dim() == 1:
            labels = labels.unsqueeze(0).expand(batch_size, -1)
        target_idx = labels.argmax(dim=1).to(feats.device)              # (B,)
        rows = torch.arange(batch_size, device=feats.device)
        target_col = feats[rows, target_idx]                             # (B, input_size)
        # Boolean mask of the two non-target colours in every row; masking keeps
        # them in their original left-to-right order.
        is_other = torch.ones(batch_size, feats.size(1), dtype=torch.bool, device=feats.device)
        is_other[rows, target_idx] = False
        others = feats[is_other].view(batch_size, 2, feats.size(2))     # (B, 2, input_size)
        col_embs = F.relu(self.deepset(others[:, 0], others[:, 1]))
        cols = torch.hstack((target_col,col_embs))
        feat = self.linear(cols)
        return feat

In [None]:
class Speaker(nn.Module):
    """GRU speaker that samples an utterance for a colour context.

    The context feature produced by ``feat_model`` initialises the hidden
    state of a 2-layer bidirectional GRU. Starting from ``<sos>``, one token
    per step is sampled with a hard Gumbel-softmax, so every emitted token is
    a one-hot row over the vocabulary.

    Relies on the module-level constants ``SOS`` / ``EOS`` and, for
    ``to_text``, on the module-level ``i2w`` id -> word mapping.
    """

    def __init__(self, feat_model, embedding_module, feat_size=16, hidden_size=100):
        super(Speaker, self).__init__()
        self.embedding = embedding_module
        self.feat_model = feat_model
        self.embedding_dim = embedding_module.embedding_dim
        self.vocab_size = embedding_module.num_embeddings
        self.hidden_size = hidden_size
        self.gru = nn.GRU(self.embedding_dim, self.hidden_size, num_layers=2, bidirectional=True)
        self.outputs2vocab = nn.Linear(self.hidden_size*2, self.vocab_size)                             # *2 for bidirectional
        self.init_h = nn.Linear(feat_size, self.hidden_size)

    def forward(self, feats, labels, tau=1, length_penalty=False, max_len=10):
        """Sample one utterance per batch row.

        Args:
            feats: colour context, forwarded unchanged to ``feat_model``.
            labels: target indicator, forwarded unchanged to ``feat_model``.
            tau: Gumbel-softmax temperature.
            length_penalty: when True, also compute a loss from the per-step
                ``<eos>`` log-probabilities.
            max_len: total utterance length including ``<sos>`` and the
                trailing ``<eos>``; ``max_len - 2`` tokens are sampled.

        Returns:
            ``(lang_tensor, lang_length, eos_loss, lang_prob)`` where
            ``lang_tensor`` is ``(B, max_len, V)`` one-hot, ``lang_length`` is
            the ``(B,)`` int64 length up to and including the first ``<eos>``,
            ``eos_loss`` is a scalar tensor (or the int ``0`` when no penalty
            is computed) and ``lang_prob`` is an empty list (it is never
            filled; kept for interface compatibility).
        """
        batch_size = feats.size(0)
        device = feats.device

        feats_emb = self.feat_model(feats, labels)

        # Initialise every (layer, direction) hidden state from the context feature.
        n_states = self.gru.num_layers * (2 if self.gru.bidirectional else 1)
        states = self.init_h(feats_emb).unsqueeze(0).repeat(n_states, 1, 1)

        # `lang` collects the one-hot tokens, one (B, 1, V) tensor per position.
        lang, eos_prob, lang_prob = [], [], []
        lang_length = torch.ones(batch_size, dtype=torch.int64, device=device)   # counts <sos>
        done_sampling = torch.zeros(batch_size, dtype=torch.bool, device=device)

        # First input is the SOS token.
        sos_onehot = torch.zeros(batch_size, 1, self.vocab_size, device=device)  # (B, 1, V)
        sos_onehot[:, 0, SOS] = 1.0
        lang.append(sos_onehot)
        # (1, B, V) @ (V, E) -> (1, B, E); the GRU expects (L, B, E).
        inputs = sos_onehot.transpose(0, 1) @ self.embedding.weight

        self.gru.flatten_parameters()
        for _ in range(max_len - 2):  # Have room for SOS, EOS if never sampled
            outputs, states = self.gru(inputs, states)          # outputs: (1, B, 2H)
            # squeeze(0) removes only the length axis; a bare squeeze() would also
            # drop the batch axis when B == 1 and break the argmax below.
            outputs = self.outputs2vocab(outputs.squeeze(0))    # (B, V)
            predicted_onehot = F.gumbel_softmax(outputs, tau=tau, hard=True)    # (B, V)
            lang.append(predicted_onehot.unsqueeze(1))

            if length_penalty:
                eos_prob.append(F.log_softmax(outputs, dim=1)[:, EOS])

            # Rows that have not emitted <eos> yet grow by one token; the <eos>
            # token itself is counted because `done` is updated afterwards.
            lang_length += (~done_sampling).to(lang_length.dtype)
            done_sampling |= predicted_onehot.argmax(1) == EOS

            # (1, B, V) @ (V, E) -> (1, B, E): embed the sampled token differentiably.
            inputs = predicted_onehot.unsqueeze(0) @ self.embedding.weight

        # Always close with EOS; it only counts for rows that never sampled one.
        eos_onehot = torch.zeros(batch_size, 1, self.vocab_size, device=device)
        eos_onehot[:, 0, EOS] = 1.0
        lang.append(eos_onehot)
        lang_length += (~done_sampling).to(lang_length.dtype)

        # Cat language tensors
        lang_tensor = torch.cat(lang, 1)                    # (B, max_len, V)

        if length_penalty and eos_prob:
            # eos prob -> eos loss: weight step t (1-based) by t, ignore steps at
            # index >= lang_length, normalise by the length and average the batch.
            eos_prob = torch.stack(eos_prob, dim=1)         # (B, max_len - 2)
            n_steps = eos_prob.size(1)
            step_weight = torch.arange(1, n_steps + 1, dtype=torch.float32, device=device)
            keep = torch.arange(n_steps, device=device).unsqueeze(0) < lang_length.unsqueeze(1)
            eos_loss = -(eos_prob * step_weight * keep)
            eos_loss = (eos_loss.sum(1) / lang_length.float()).mean()
        else:
            eos_loss = 0

        return lang_tensor, lang_length, eos_loss, lang_prob

    def to_text(self, lang_onehot):
        """Decode one-hot utterances ``(B, L, V)`` into strings.

        Every row is decoded up to and including its first ``<eos>`` using the
        module-level ``i2w`` mapping; words are joined with single spaces.

        Returns:
            A NumPy string array with one entry per batch row.
        """
        texts = []
        for sample in lang_onehot.argmax(2).cpu().tolist():
            words = []
            for item in sample:
                words.append(i2w[item])
                if item == EOS:
                    break
            texts.append(' '.join(words))
        # np.str_ replaces np.unicode_, which was removed in NumPy 2.0.
        return np.array(texts, dtype=np.str_)

### Model pipeline check

In [None]:
from literal_listener_color import SimpleBaseLine_L0

# Smoke test of the speaker -> literal listener pipeline on roughly the first
# 100 training batches, with a freshly initialised speaker.
emb_dim = 768
speaker_embs = nn.Embedding(len(vocab_dict), emb_dim)
speaker_feat = Colors_Feature(output_size=16)
speaker = Speaker(speaker_feat, speaker_embs)
speaker.to(device)

literal_listener = SimpleBaseLine_L0(len(vocab_dict)).to(device)
literal_listener.load_state_dict(torch.load("model_params/baseline_fixed-vocab_l0.pth",map_location=device))

# Short utterances for the check; the full length would be max_context_len + 2
# (<sos> + longest context + <eos>).
max_len = 10

# Only the speaker's parameters are optimised.
optimizer = optim.Adam(list(speaker.parameters()),lr=0.001)
criterion = nn.CrossEntropyLoss()

losss = []
accs = []

for i,((cols,lang),label) in enumerate(train_batch):
    cols, lang = cols.to(device), lang.to(device)
    label = label.to(device)
    optimizer.zero_grad()
    literal_listener.eval()
    speaker.train()
    lang_tensor,lang_length,eos_loss,lang_prob = speaker(cols, label, tau=1, length_penalty=False, max_len=max_len)
    # Recover token ids as <one-hot, [0, 1, ..., V-1]> instead of argmax.
    # NOTE(review): the cast to torch.int below yields an integer tensor, so no
    # gradient can flow from the listener loss back into the speaker through
    # `output_lang2`. Making this differentiable would require the listener to
    # accept one-hot / embedded input - confirm against SimpleBaseLine_L0.
    args = torch.arange(lang_tensor.size(2), dtype=torch.float, device=device)
    output_lang2 = (lang_tensor @ args).to(torch.int)                        # (B, L)
    # Score each of the three colours against the generated utterance.
    lis_scores01 = literal_listener(cols[:,0], output_lang2)
    lis_scores02 = literal_listener(cols[:,1], output_lang2)
    lis_scores03 = literal_listener(cols[:,2], output_lang2)
    lis_labels = torch.hstack((lis_scores01,lis_scores02,lis_scores03))
    loss = criterion(lis_labels,label) + eos_loss*0.0001
    loss.backward()
    optimizer.step()
    if i%10 == 0:
        print("\nOriginal sentence:\n"+" ".join([i2w[idx] for idx in lang[0].to("cpu").tolist()]).replace(" <pad>",""))
        print("Generated sentence:\n"+" ".join([i2w[idx] for idx in lang_tensor.argmax(2)[0].to("cpu").tolist()]).replace(" <pad>","")+"\n")
        print("Loss: ",loss.item())
        losss.append(loss.item())
        pred_labels = torch.argmax(lis_labels,dim=1)
        # Compare against the target index stored in the labels rather than a hard-coded 2.
        acc = (pred_labels == label.argmax(dim=1)).float().mean()
        print("Accuracy:",acc)
        accs.append(acc.item())
    if i > 100: break


## Training the model

### Setting

In [None]:
from literal_listener_color import SimpleBaseLine_L0

# Fresh speaker for the real training run.
emb_dim = 768
speaker_embs = nn.Embedding(num_embeddings=len(vocab_dict), embedding_dim=emb_dim)
speaker_feat = Colors_Feature(output_size=16)
speaker = Speaker(feat_model=speaker_feat, embedding_module=speaker_embs)
speaker.to(device)

# Literal listener restored from its saved weights.
literal_listener = SimpleBaseLine_L0(len(vocab_dict)).to(device)
listener_checkpoint = "model_params/baseline_fixed-vocab_l0.pth"
literal_listener.load_state_dict(torch.load(listener_checkpoint, map_location=device))

# Generated utterance length including <sos> and <eos>
# (the full length would be max_context_len + 2).
max_len = 5

# Only the speaker's parameters are handed to the optimiser.
optimizer = optim.Adam(speaker.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

epoch = 10

In [None]:
def check_data(cols, label, c_lang, g_lang, lis_label):
    """Plot one three-colour context with the listener's choice.

    Args:
        cols: tensor whose first three entries are the colours (each is
            converted with ``.numpy()``, so it must live on the CPU).
        label: per-colour target indicator; a colour whose entry is > 0.5 is
            drawn with a black edge colour, the others with their own colour.
        c_lang: list of words of the reference utterance.
        g_lang: list of words of the generated utterance.
        lis_label: listener scores; the panel at their argmax is titled
            "Listener Prediction".
    """
    swatch_colors = [cols[k].numpy() for k in range(3)]
    caption = "Correct: {}\nGenerated:{}\n ".format(" ".join(c_lang), " ".join(g_lang))
    fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(4, 6))
    fig.suptitle(caption)
    for position, (ax, face) in enumerate(zip(axes, swatch_colors)):
        edge = "black" if label[position] > 0.5 else face
        ax.add_patch(mpatch.Rectangle((0, 0), 1, 1, color=face, ec=edge, lw=8))
        ax.axis('off')
        if torch.argmax(lis_label) == position:
            ax.set_title("Listener Prediction")
    plt.show()

### Start

In [None]:
# Per-epoch history of the mean batch loss / accuracy.
train_loss_list = []
test_loss_list = []
train_acc_list = []
test_acc_list = []
best_loss = float("inf")   # any finite first loss counts as an improvement
best_acc = 0
for i in range(epoch):
    print("##############################################")
    print("Epoch:{}/{}".format(i+1,epoch))
    train_loss = 0
    test_loss = 0
    train_acc = 0
    test_acc = 0

    # ---- training pass ----
    # NOTE(review): the listener is left in train mode while the speaker trains,
    # as in the original cell (the pipeline-check cell uses eval mode instead) -
    # confirm which one is intended. Only speaker parameters are in the optimiser.
    literal_listener.train()
    speaker.train()
    for (cols,lang),label in train_batch:
        cols, lang, label = cols.to(device), lang.to(device), label.to(device)
        optimizer.zero_grad()
        lang_tensor,lang_length,eos_loss,lang_prob = speaker(cols, label, tau=1, length_penalty=False, max_len=max_len)
        # Token ids as <one-hot, [0..V-1]>.
        # NOTE(review): the int cast stops gradients from reaching the speaker.
        args = torch.arange(lang_tensor.size(2), dtype=torch.float, device=device)
        output_lang = (lang_tensor @ args).to(torch.int)
        lis_scores01 = literal_listener(cols[:,0], output_lang)
        lis_scores02 = literal_listener(cols[:,1], output_lang)
        lis_scores03 = literal_listener(cols[:,2], output_lang)
        lis_logits = torch.hstack((lis_scores01,lis_scores02,lis_scores03))
        # CrossEntropyLoss applies log-softmax itself, so it receives the raw
        # scores (the original applied softmax first, i.e. softmax twice).
        lis_loss = criterion(lis_logits,label)
        loss = lis_loss + eos_loss*0.0001
        # A new graph is built every iteration, so retain_graph is not needed.
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        pred_labels = torch.argmax(lis_logits,dim=1)
        train_acc += (pred_labels == label.argmax(dim=1)).float().mean().item()
    batch_train_loss = train_loss/len(train_batch)
    batch_train_acc = train_acc/len(train_batch)

    # ---- evaluation pass ----
    speaker.eval()
    literal_listener.eval()   # evaluation must not run the listener in train mode
    with torch.no_grad():
        for (cols,lang),label in test_batch:
            cols, lang, label = cols.to(device), lang.to(device), label.to(device)
            lang_tensor,lang_length,eos_loss,lang_prob = speaker(cols, label, tau=1, length_penalty=False, max_len=max_len)
            args = torch.arange(lang_tensor.size(2), dtype=torch.float, device=device)
            output_lang = (lang_tensor @ args).to(torch.int)
            lis_scores01 = literal_listener(cols[:,0], output_lang)
            lis_scores02 = literal_listener(cols[:,1], output_lang)
            lis_scores03 = literal_listener(cols[:,2], output_lang)
            lis_logits = torch.hstack((lis_scores01,lis_scores02,lis_scores03))
            lis_labels = F.softmax(lis_logits,dim=-1)   # probabilities, used for display only
            loss = criterion(lis_logits,label) + eos_loss*0.0001
            test_loss += loss.item()
            pred_labels = torch.argmax(lis_logits,dim=1)
            test_acc += (pred_labels == label.argmax(dim=1)).float().mean().item()
        batch_test_loss = test_loss/len(test_batch)
        batch_test_acc = test_acc/len(test_batch)

    print("Train Loss:{:.2E}, Test Loss:{:.2E}".format(batch_train_loss,batch_test_loss))
    print("Train Acc:{:.2E}, Test Acc:{:.2E}".format(batch_train_acc,batch_test_acc))
    train_loss_list.append(batch_train_loss)
    test_loss_list.append(batch_test_loss)
    train_acc_list.append(batch_train_acc)
    test_acc_list.append(batch_test_acc)
    # Keep the checkpoints with the lowest test loss and the highest test accuracy.
    if batch_test_loss < best_loss:
        print("Best loss saved ...")
        torch.save(speaker.state_dict(),"model_params/color_S1_lis=baseline_birnnX2_no-penalty_gumble-argmax_loss=L0_best_loss.pth")
        best_loss = batch_test_loss
    if batch_test_acc > best_acc:
        print("Best acc saved ...")
        torch.save(speaker.state_dict(),"model_params/color_S1_lis=baseline_birnnX2_no-penalty_gumble-argmax_loss=L0_best_acc.pth")
        best_acc = batch_test_acc
    # Show one random example from the last test batch after every epoch.
    sample_id = np.random.randint(len(cols))
    c_langs = [i2w[idx] for idx in lang[sample_id].to("cpu").tolist() if idx not in [PAD,SOS,EOS]]
    g_langs = [i2w[idx] for idx in lang_tensor.argmax(2)[sample_id].to("cpu").tolist() if idx not in [PAD,SOS,EOS]]
    check_data(cols[sample_id].to("cpu"), label[sample_id], c_langs, g_langs, lis_labels[sample_id])


In [None]:
# visualization
# Loss curves: one point per epoch, train in solid blue and test in dashed red.
epoch_axis = range(1, epoch + 1)
fig_loss, ax_loss = plt.subplots()
ax_loss.set_title("Train and Test Loss")
ax_loss.set_xlabel("epoch")
ax_loss.set_ylabel("Literal_Listener_loss")
ax_loss.plot(epoch_axis, train_loss_list, "b-", label="train_loss")
ax_loss.plot(epoch_axis, test_loss_list, "r--", label="test_loss")
ax_loss.legend()
plt.show()

# Accuracy curves; the histories are converted to plain floats first.
train_acc_list = list(map(float, train_acc_list))
test_acc_list = list(map(float, test_acc_list))
fig_acc, ax_acc = plt.subplots()
ax_acc.set_title("Train and Test Accuracy")
ax_acc.set_xlabel("epoch")
ax_acc.set_ylabel("Accuracy")
ax_acc.plot(epoch_axis, train_acc_list, "b-", label="train_acc")
ax_acc.plot(epoch_axis, test_acc_list, "r--", label="test_acc")
ax_acc.legend()
plt.show()

## Accuracy test

In [None]:
# Restore a trained speaker for evaluation. The path uses a forward slash like
# every other path in this notebook: the original "model_params\c..." contained
# the invalid escape sequence "\c" and only resolved on Windows.
# NOTE(review): this file name differs from the checkpoints written by the
# training cell ("..._gumble-argmax_loss=L0_best_acc.pth") - confirm that this
# is the checkpoint you mean to evaluate.
speaker.load_state_dict(torch.load("model_params/color_S1_lis=baseline_original_birnnX2_no-penalty_loss=L0_best_acc.pth",map_location=device))
speaker.to(device)

### L0 communication Accuracy

In [None]:
# Communication accuracy: how often the literal listener picks the target
# colour when it reads the speaker's utterance.
losss = []
accs = []
speaker.eval()
literal_listener.eval()

# Pure evaluation: no gradients are needed.
with torch.no_grad():
    for i,((cols,lang),label) in enumerate(test_batch):
        cols, lang = cols.to(device), lang.to(device)
        label = label.to(device)
        lang_tensor,lang_length,eos_loss,lang_prob = speaker(cols, label, length_penalty=False, max_len=max_len)
        output_lang = lang_tensor.argmax(2)
        lis_scores01 = literal_listener(cols[:,0], output_lang)
        lis_scores02 = literal_listener(cols[:,1], output_lang)
        lis_scores03 = literal_listener(cols[:,2], output_lang)
        lis_logits = torch.hstack((lis_scores01,lis_scores02,lis_scores03))
        # Explicit dim: calling softmax without it is deprecated.
        lis_labels = F.softmax(lis_logits,dim=-1)
        # CrossEntropyLoss expects raw scores; the value is now recorded
        # (the original computed it and threw it away).
        loss = criterion(lis_logits,label) + eos_loss*0.0001
        losss.append(loss.item())
        pred_labels = torch.argmax(lis_labels,dim=1)
        acc = (pred_labels == label.argmax(dim=1)).float().mean()
        accs.append(acc.item())
        if i%10 == 0:
            # Show one random example of this batch without rebinding the batch tensors.
            sample_id = np.random.randint(len(cols))
            c_langs = [i2w[idx] for idx in lang[sample_id].to("cpu").tolist() if idx not in [PAD,SOS,EOS]]
            g_langs = [i2w[idx] for idx in lang_tensor.argmax(2)[sample_id].to("cpu").tolist() if idx not in [PAD,SOS,EOS]]
            print(int(torch.argmax(lis_labels[sample_id])))
            check_data(cols[sample_id].to("cpu"), label[sample_id], c_langs, g_langs, lis_labels[sample_id])

print("Accuracy:,",np.mean(accs))

### L1 accuracy

In [None]:
def _reference_log_score(sampled_onehot, reference_ids, eps=0.001):
    """Score reference utterances against sampled one-hot utterances.

    For every position shared by both sequences, take the sampled one-hot
    entry at the reference token id, add ``eps``, take the log, and sum over
    positions. Returns a ``(B,)`` tensor.
    """
    n_steps = min(sampled_onehot.size(1), reference_ids.size(1))
    picked = sampled_onehot[:, :n_steps].gather(2, reference_ids[:, :n_steps].unsqueeze(-1)).squeeze(-1)
    return torch.log(picked + eps).sum(dim=1)


losss = []
accs = []
speaker.eval()

# Pure evaluation: no gradients are needed.
with torch.no_grad():
    for i,((cols,lang),label) in enumerate(test_batch):
        cols, lang, label = cols.to(device).to(torch.float), lang.to(device), label.to(device).to(torch.float)
        # Let the speaker describe each of the three colours in turn as if it
        # were the target, and score the reference utterance under each.
        output_langs, scores = [], []
        for target in range(3):
            target_label = torch.zeros_like(label)
            target_label[:,target] = 1.0
            sampled,lang_length,eos_loss,lang_prob = speaker(cols, target_label, length_penalty=False, max_len=max_len)
            output_langs.append(sampled.argmax(2))
            scores.append(_reference_log_score(sampled, lang))
        # (B, 3): normalised score of each colour being the described one.
        probs = F.softmax(torch.stack(scores, dim=1).cpu(),dim=-1)
        loss = criterion(probs.to(device),label) + eos_loss*0.0001
        losss.append(loss.item())
        pred_labels = torch.argmax(probs,dim=1)
        acc = (pred_labels == label.argmax(dim=1).cpu()).float().mean()
        accs.append(acc.item())

        if i%10 == 0:
            sample_id = np.random.randint(len(cols))
            imgs = cols[sample_id].to("cpu")
            c_langs = [i2w[idx] for idx in lang[sample_id].to("cpu").tolist() if idx not in [EOS,SOS,PAD]]
            # The three generated utterances, separated by "|".
            g_langs = []
            for k, ids in enumerate(output_langs):
                if k > 0:
                    g_langs.append("|")
                g_langs += [i2w[idx] for idx in ids[sample_id].to("cpu").tolist() if idx not in [EOS,SOS,PAD]]
            check_data(imgs, label[sample_id], c_langs, g_langs, probs[sample_id])
            # NOTE(review): `probs` is already a softmax output, so exp() of it is
            # not a probability - kept as in the original, confirm the intent.
            print(torch.exp(probs[sample_id]))
            print(torch.where(torch.exp(probs[sample_id])>0.1,1,0))

print("Loss: ",np.mean(losss))
print("Accuracy: ",np.mean(accs))