In [None]:
from pathlib import Path
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from shapeworld_data import load_raw_data, get_vocab, ShapeWorld

In [None]:
# Run on the GPU when PyTorch can see one; otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

## Prepare Data

In [None]:
def check_raw_data(imgs, labels, langs, id=0):
    """Display one raw sample: its candidate images under the caption.

    Args:
        imgs: Indexable collection of samples; ``imgs[id]`` is iterated to get
            one array per candidate image.
        labels: Indexable collection; ``labels[id]`` holds one flag per image,
            and a flag equal to ``1`` gets the subplot title "Correct".
        langs: Indexable collection; ``langs[id]`` is a sequence of word strings
            that is space-joined into the figure title.
        id: Index of the sample to show. The name shadows the ``id`` builtin but
            is kept because callers pass it by keyword.
    """
    # Index the requested sample directly instead of zipping the whole dataset
    # into a list just to pick one element out of it.
    img_list, label, lang = imgs[id], labels[id], langs[id]
    # The figure always has three panels, so a sample is expected to hold at
    # most three images.
    fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(6, 2))
    fig.suptitle(" ".join(lang))
    for i, (l, img) in enumerate(zip(label, img_list)):
        # Reverse the axis order before plotting
        # (presumably channel-first -> channel-last; TODO confirm against the dataset).
        img = img.transpose((2, 1, 0))
        axes[i].imshow(img)
        if l == 1:
            axes[i].set_title("Correct")
    plt.show()

In [None]:
# Locate the dataset relative to the notebook: two directories above the current
# working directory, under data/shapeworld_np.
root = Path(os.path.abspath('')).parent.parent.absolute()
# Join the components separately: the previous single "data\shapeworld_np" literal
# relied on the invalid escape sequence "\s" and only resolved on Windows.
data_path = os.path.join(root, "data", "shapeworld_np")
print(data_path)
# os.listdir returns entries in arbitrary order; sort them so that data_list[0],
# data_list[1] and data_list[2] (used below for the train / test / eval splits)
# refer to the same files on every run and every platform.
data_list = sorted(os.listdir(data_path))
print(data_list)

### Generating vocab_dict 

In [None]:
# Build one vocabulary shared by every file in the dataset directory.
vocab = get_vocab([os.path.join(data_path, fname) for fname in data_list])
print(vocab["w2i"])

# One-hot codes for colour words; "other" is the all-zero vector.
COLOR = {
    "white":  [1,0,0,0,0,0],
    "green":  [0,1,0,0,0,0],
    "gray":   [0,0,1,0,0,0],
    "yellow": [0,0,0,1,0,0],
    "red":    [0,0,0,0,1,0],
    "blue":   [0,0,0,0,0,1],
    "other":  [0,0,0,0,0,0],
}
# One-hot codes for shape words; the generic word "shape" is the all-zero vector.
SHAPE = {
    "shape":     [0,0,0,0],
    "square":    [1,0,0,0],
    "circle":    [0,1,0,0],
    "rectangle": [0,0,1,0],
    "ellipse":   [0,0,0,1],
}

# Word <-> index lookup tables used throughout the notebook.
w2i, i2w = vocab["w2i"], vocab["i2w"]

# Special-token indices (presumably matching the entries of `vocab` -- verify
# against get_vocab).
PAD, SOS, EOS, UNK = 0, 1, 2, 3

### Prepare the Data loader

In [None]:
# Training split: the first file of the dataset directory.
d = load_raw_data(os.path.join(data_path, data_list[0]))
# One shape per line: images, labels, captions.
print(*(d[field].shape for field in ("imgs", "labels", "langs")), sep="\n")
check_raw_data(d["imgs"], d["labels"], d["langs"])
train_batch = DataLoader(ShapeWorld(d, vocab), shuffle=False, batch_size=32)

In [None]:
# Test split: the second file of the dataset directory.
d = load_raw_data(os.path.join(data_path, data_list[1]))
# One shape per line: images, labels, captions.
print(*(d[field].shape for field in ("imgs", "labels", "langs")), sep="\n")
check_raw_data(d["imgs"], d["labels"], d["langs"], id=1)
test_batch = DataLoader(ShapeWorld(d, vocab), shuffle=False, batch_size=32)

In [None]:
# Look at the first batch only, to confirm the tensor shapes the loader yields.
for imgs, label, lang in train_batch:
    print(imgs.shape, lang.shape, label.shape, sep="\n")
    break

## Model

In [None]:
def to_onehot(y, n=3):
    """Convert integer class indices into one-hot rows.

    Args:
        y: Tensor of class indices. Any shape is accepted; it is flattened, so
            every element produces one output row. Values must lie in ``[0, n)``.
        n: Number of classes, i.e. the width of each one-hot row.

    Returns:
        Tensor of shape ``(y.numel(), n)`` with the default floating dtype, on
        the same device as ``y``, holding a single ``1`` per row.
    """
    # reshape (rather than view) also accepts non-contiguous index tensors.
    index = y.to(torch.int64).reshape(-1, 1)
    # Size the output from the flattened index so multi-column inputs work, and
    # allocate directly on the target device instead of copying from the CPU.
    y_onehot = torch.zeros(index.shape[0], n, device=y.device)
    y_onehot.scatter_(1, index, 1)
    return y_onehot

In [None]:
class Imgs_emb_DeepSet(nn.Module):
    """Pool two image embeddings by summing them, then apply one linear layer.

    Because the pooling is a sum, the result does not depend on the order of the
    two inputs.
    """

    def __init__(self, input_size=1024, output_size=1024):
        super().__init__()
        self.linear1 = nn.Linear(input_size, output_size)

    def forward(self, img_emb1, img_emb2):
        """Return ``linear1(img_emb1 + img_emb2)``.

        Both inputs must share a shape whose last dimension is ``input_size``;
        the output has ``output_size`` as its last dimension.
        """
        pooled = img_emb1 + img_emb2
        return self.linear1(pooled)

In [None]:
from vision import ConvNet
from literal_listener_shapeworld import CNN_encoder

class Imgs_Feature(nn.Module):
    """Encode a set of candidate images into a single feature vector.

    Every image is embedded by a shared ``ConvNet``. The embedding in slot 0 is
    kept as-is (the code treats it as the target image), slots 1 and 2 are pooled
    by ``Imgs_emb_DeepSet``, and the two parts are concatenated and projected.
    """

    def __init__(self, input_size=1024, output_size =1024):
        super().__init__()
        self.cnn_encoder = ConvNet(4)
        self.deepset = Imgs_emb_DeepSet(input_size, output_size)
        # Input is [pooled pair (output_size) | slot-0 embedding (input_size)].
        self.linear = nn.Linear(output_size + input_size, output_size)

    def forward(self, feats):
        """Map ``feats`` of shape ``(batch, n_obj, ...)`` to ``(batch, output_size)``.

        Only the first three objects along dimension 1 are used.
        """
        batch_size, n_obj = feats.shape[0], feats.shape[1]
        # Fold the object axis into the batch axis so the CNN sees plain images.
        flat_imgs = feats.reshape(batch_size * n_obj, *feats.shape[2:])
        cnn_emb = self.cnn_encoder(flat_imgs).view(batch_size, n_obj, -1)
        target_emb = cnn_emb[:, 0]
        pair_emb = F.relu(self.deepset(cnn_emb[:, 1], cnn_emb[:, 2]))
        joint = torch.cat((pair_emb, target_emb), dim=1)
        return self.linear(joint)

In [None]:
class Speaker(nn.Module):
    """GRU speaker that samples an utterance from image features (Gumbel-softmax).

    ``feat_model`` maps the raw input to a ``feat_size`` vector that initialises
    the GRU hidden state; tokens are then sampled one at a time as hard
    (straight-through) Gumbel-softmax one-hot vectors.

    NOTE(review): a later cell defines another class that is also named
    ``Speaker``; once that cell has run, this definition is shadowed.
    """

    def __init__(self, feat_model, embedding_module, feat_size=1024, hidden_size=100):
        super(Speaker, self).__init__()
        self.embedding = embedding_module
        self.feat_model = feat_model
        self.embedding_dim = embedding_module.embedding_dim
        self.vocab_size = embedding_module.num_embeddings
        self.hidden_size = hidden_size
        self.gru = nn.GRU(self.embedding_dim, self.hidden_size, bidirectional=False)
        # The input width would have to be doubled if the GRU were bidirectional.
        self.outputs2vocab = nn.Linear(self.hidden_size*1, self.vocab_size)
        self.init_h = nn.Linear(feat_size, self.hidden_size)

    def forward(self, feats, tau=1, length_penalty=False, max_len=40):
        """Sample one utterance per batch element.

        Args:
            feats: Input passed to ``feat_model``; its first dimension is the batch.
            tau: Gumbel-softmax temperature.
            length_penalty: When true, also compute a position-weighted EOS loss.
            max_len: Maximum utterance length including the SOS and EOS tokens.

        Returns:
            Tuple ``(lang_tensor, lang_length, eos_loss, lang_prob)``:
            ``lang_tensor`` is ``(B, L, V)`` one-hot tokens, zeroed past each
            sequence's length; ``lang_length`` is ``(B,)`` lengths including
            SOS/EOS; ``eos_loss`` is a scalar tensor (or ``0`` when disabled);
            ``lang_prob`` is ``(B,)``, the EOS log-probability after the last
            sampled token.
        """
        batch_size = feats.size(0)

        # Initialise the hidden state from the image features.
        feats_emb = self.feat_model(feats)
        states = self.init_h(feats_emb).unsqueeze(0)

        # `lang` collects the sampled one-hot tokens, one (B, 1, V) tensor per step.
        lang, eos_prob, lang_prob = [], [], []
        lang_length = torch.ones(batch_size, dtype=torch.int64).to(feats.device)
        done_sampling = [False for _ in range(batch_size)]

        # The first input is the SOS token.
        inputs_onehot = torch.zeros(batch_size, self.vocab_size).to(feats.device)   # (B, V)
        inputs_onehot[:, SOS] = 1.0
        inputs_onehot = inputs_onehot.unsqueeze(1)                                  # (B, 1, V)
        lang.append(inputs_onehot)

        inputs_onehot = inputs_onehot.transpose(0, 1)                               # (1, B, V)
        inputs = inputs_onehot @ self.embedding.weight                              # (1, B, E)

        for i in range(max_len - 2):  # leave room for SOS, and EOS if never sampled
            if all(done_sampling): break

            self.gru.flatten_parameters()
            outputs, states = self.gru(inputs, states)                          # (1, B, H)
            # Drop only the length axis: a bare squeeze() would also remove the
            # batch axis when the batch holds a single sample.
            outputs = outputs.squeeze(0)                                        # (B, H)
            outputs = self.outputs2vocab(outputs)                               # (B, V)
            predicted_onehot = F.gumbel_softmax(outputs, tau=tau, hard=True)    # (B, V)
            lang.append(predicted_onehot.unsqueeze(1))

            if length_penalty:
                idx_prob = F.log_softmax(outputs, dim = 1)
                eos_prob.append(idx_prob[:,EOS])

            predicted_npy = predicted_onehot.argmax(1).cpu().numpy()            # (B,)

            # Update language lengths; a sequence stops growing after its EOS.
            for j, pred in enumerate(predicted_npy):
                if not done_sampling[j]: lang_length[j] += 1
                if pred == EOS: done_sampling[j] = True

            inputs = (predicted_onehot.unsqueeze(0)) @ self.embedding.weight    # (1, B, E)

        # One more step to read off the EOS log-probability after the last token.
        self.gru.flatten_parameters()
        outputs, states = self.gru(inputs, states)  # (1, B, H)
        outputs = outputs.squeeze(0)                # (B, H)
        outputs = self.outputs2vocab(outputs)       # (B, V)
        idx_prob = F.log_softmax(outputs, dim=1)    # (B, V)
        lang_prob.append(idx_prob[:, EOS].unsqueeze(1))    # (B, 1)

        # Append EOS for the sequences that never sampled it.
        eos_onehot = torch.zeros(batch_size, 1, self.vocab_size).to(feats.device)
        eos_onehot[:, 0, EOS] = 1.0
        lang.append(eos_onehot)
        # Iterate over the batch itself: `predicted_npy` does not exist when the
        # sampling loop ran zero times (max_len <= 2).
        for i in range(batch_size):
            if not done_sampling[i]: lang_length[i] += 1
            done_sampling[i] = True

        # Concatenate the tokens and blank everything past each sequence's length.
        lang_tensor = torch.cat(lang, 1)                    # (B, L, V)
        for i in range(lang_tensor.shape[0]):
            lang_tensor[i, lang_length[i]:] = 0
        max_lang_len = max_len
        lang_tensor = lang_tensor[:, :max_lang_len, :]

        lang_prob_tensor = torch.cat(lang_prob, 1)              # (B, 1)
        for i in range(lang_prob_tensor.shape[0]):
            lang_prob_tensor[i, lang_length[i]:] = 0
        lang_prob_tensor = lang_prob_tensor[:, :max_lang_len]
        lang_prob = lang_prob_tensor.sum(1)                     # (B,)

        # `eos_prob` is empty when no token was sampled; there is nothing to
        # penalise in that case, so fall through to the zero loss.
        if length_penalty and eos_prob:
            # Weight the EOS log-probability by its 1-based position, mask the
            # positions past each sequence's length, and average.
            eos_prob = torch.stack(eos_prob, dim = 1)
            for i in range(eos_prob.shape[0]):
                r_len = torch.arange(1,eos_prob.shape[1]+1,dtype=torch.float32)
                eos_prob[i] = eos_prob[i]*r_len.to(eos_prob.device)
                eos_prob[i, lang_length[i]:] = 0
            eos_loss = -eos_prob
            eos_loss = eos_loss.sum(1)/lang_length.float()
            eos_loss = eos_loss.mean()
        else:
            eos_loss = 0

        return lang_tensor, lang_length, eos_loss, lang_prob

    def to_text(self, lang_onehot):
        """Decode ``(B, L, V)`` one-hot utterances into space-joined strings.

        Each sentence is cut after its first EOS token. Words are looked up in
        the notebook-level ``i2w`` table (the previous code read an undefined
        ``data`` object).
        """
        texts = []
        lang = lang_onehot.argmax(2)
        for sample in lang.cpu().numpy():
            text = []
            for item in sample:
                text.append(i2w[item])
                if item == EOS:
                    break
            texts.append(' '.join(text))
        # np.str_ is the name that survives NumPy 2.0; np.unicode_ was removed.
        return np.array(texts, dtype=np.str_)

In [None]:
class Speaker(nn.Module):
    """GRU speaker conditioned on three images and a target indicator.

    ``feat_model`` embeds each image; the embeddings, each extended by its target
    flag, are concatenated and projected to the initial GRU hidden state. Tokens
    are then produced step by step with the selected ``activation``.
    """

    def __init__(self, feat_model, embedding_module, hidden_size=100):
        super(Speaker, self).__init__()
        self.embedding = embedding_module
        self.feat_model = feat_model
        self.feat_size = feat_model.final_feat_dim
        self.embedding_dim = embedding_module.embedding_dim
        self.vocab_size = embedding_module.num_embeddings
        self.hidden_size = hidden_size

        self.gru = nn.GRU(self.embedding_dim, self.hidden_size)
        self.outputs2vocab = nn.Linear(self.hidden_size, self.vocab_size)
        # Three objects, each contributing its feature vector plus a 1/0 target flag.
        self.init_h = nn.Linear(3 * (self.feat_size + 1), self.hidden_size)

    def embed_features(self, feats, targets):
        """Embed every image and append its target flag.

        Args:
            feats: ``(batch, n_obj, ...)`` images.
            targets: ``(batch, n_obj)`` flags, one per image (already one-hot).

        Returns:
            ``(batch, n_obj * (feat_dim + 1))`` tensor.
        """
        batch_size = feats.shape[0]
        n_obj = feats.shape[1]
        rest = feats.shape[2:]
        # reshape (rather than view) also accepts non-contiguous image batches.
        feats_flat = feats.reshape(batch_size * n_obj, *rest)
        feats_emb_flat = self.feat_model(feats_flat)
        feats_emb = feats_emb_flat.unsqueeze(1).view(batch_size, n_obj, -1)
        # Targets arrive one-hot already, so they are appended as an extra feature.
        targets_onehot = targets
        feats_and_targets = torch.cat((feats_emb, targets_onehot.unsqueeze(2)), 2)
        ft_concat = feats_and_targets.view(batch_size, -1)
        return ft_concat

    def forward(self, feats, targets, greedy=False, activation='gumbel', tau = 1, length_penalty=False, max_len=40):
        """Sample one utterance per batch element from the image features.

        Args:
            feats: ``(batch, n_obj, ...)`` images.
            targets: ``(batch, n_obj)`` target flags.
            greedy: Pick the arg-max token at every step instead of sampling.
                No per-token log-probabilities are recorded in this mode.
            activation: ``'gumbel'`` (or ``None``), ``'softmax'``,
                ``'softmax_noise'`` or ``'multinomial'``.
            tau: Temperature for the softmax / Gumbel-softmax activations.
            length_penalty: When true, also compute a position-weighted EOS loss.
            max_len: Maximum utterance length including SOS and EOS.

        Returns:
            Tuple ``(lang_tensor, lang_length, eos_loss, lang_prob)``:
            ``lang_tensor`` is ``(B, L, V)`` with ``L`` the longest sequence,
            zeroed past each sequence's length; ``lang_length`` is ``(B,)``;
            ``eos_loss`` is a scalar tensor or ``0``; ``lang_prob`` is ``(B,)``
            summed log-probabilities for ``'multinomial'`` and ``None`` otherwise.

        Raises:
            NotImplementedError: For an unknown ``activation`` when not greedy.
        """
        batch_size = feats.size(0)
        feats_emb = self.embed_features(feats, targets)
        # Initialise the hidden state from the image features.
        states = self.init_h(feats_emb)
        states = states.unsqueeze(0)
        # `lang` collects the sampled (possibly soft) one-hot tokens.
        lang = []
        if length_penalty:
            eos_prob = []
        if activation == 'multinomial':
            lang_prob = []
        else:
            lang_prob = None
        # Sequence lengths, starting at 1 for the SOS token.
        lang_length = torch.ones(batch_size, dtype=torch.int64).to(feats.device)
        done_sampling = [False for _ in range(batch_size)]
        # The first input is the SOS token: (B, V) -> (B, 1, V).
        inputs_onehot = torch.zeros(batch_size, self.vocab_size).to(feats.device)
        inputs_onehot[:, SOS] = 1.0
        inputs_onehot = inputs_onehot.unsqueeze(1)
        lang.append(inputs_onehot)
        # (B, L, V) -> (L, B, V), then embed: (1, B, V) @ (V, E) -> (1, B, E).
        inputs_onehot = inputs_onehot.transpose(0, 1)
        inputs = inputs_onehot @ self.embedding.weight

        for i in range(max_len - 2):  # leave room for SOS, and EOS if never sampled
            # FIXME: sampling continues for finished sequences until every
            # sequence in the batch is done.
            if all(done_sampling):
                break
            self.gru.flatten_parameters()
            outputs, states = self.gru(inputs, states)  # (1, B, H)
            outputs = outputs.squeeze(0)                # (B, H)
            outputs = self.outputs2vocab(outputs)       # (B, V)

            if greedy:
                # Greedy decoding used to leave `predicted_onehot` unset and never
                # appended a token; build the arg-max one-hot explicitly.
                predicted_onehot = F.one_hot(outputs.argmax(1), self.vocab_size).to(outputs.dtype)
            elif activation == 'gumbel' or activation is None:
                predicted_onehot = F.gumbel_softmax(outputs, tau=tau, hard=True)
            elif activation == 'softmax':
                # dim is given explicitly; the implicit choice is deprecated.
                predicted_onehot = F.softmax(outputs / tau, dim=1)
            elif activation == 'softmax_noise':
                predicted_onehot = F.gumbel_softmax(outputs, tau=tau, hard=False)
            elif activation == 'multinomial':
                # Plain non-differentiable sampling, meant for REINFORCE training.
                TEMP = 5.0
                idx_prob = F.log_softmax(outputs / TEMP, dim=1)
                predicted = torch.multinomial(idx_prob.exp(), 1)
                predicted_onehot = to_onehot(predicted, n=self.vocab_size)
                predicted_logprob = torch.gather(idx_prob, 1, predicted)
                lang_prob.append(predicted_logprob)
            else:
                raise NotImplementedError(activation)

            # Record the token for every decoding mode.
            lang.append(predicted_onehot.unsqueeze(1))
            if length_penalty:
                idx_prob = F.log_softmax(outputs, dim = 1)
                eos_prob.append(idx_prob[:,EOS])

            predicted_npy = predicted_onehot.argmax(1).cpu().numpy()

            # Update language lengths; only the hard-sampling activations are
            # allowed to terminate a sequence on EOS.
            for j, pred in enumerate(predicted_npy):
                if not done_sampling[j]:
                    lang_length[j] += 1
                if pred == EOS and activation in {'gumbel', 'multinomial'}:
                    done_sampling[j] = True

            # (1, B, V) @ (V, E) -> (1, B, E)
            inputs = (predicted_onehot.unsqueeze(0)) @ self.embedding.weight

        # For multinomial, run the GRU once more to get the log-probability of
        # EOS (in case sampling went that far).
        if activation == 'multinomial':
            self.gru.flatten_parameters()
            outputs, states = self.gru(inputs, states)  # (1, B, H)
            outputs = outputs.squeeze(0)                # (B, H)
            outputs = self.outputs2vocab(outputs)       # (B, V)
            idx_prob = F.log_softmax(outputs, dim=1)
            lang_prob.append(idx_prob[:, EOS].unsqueeze(1))

        # Append EOS for the sequences that never sampled it.
        eos_onehot = torch.zeros(batch_size, 1, self.vocab_size).to(feats.device)
        eos_onehot[:, 0, EOS] = 1.0
        lang.append(eos_onehot)

        # Close every open sequence. Iterate over the batch itself:
        # `predicted_npy` does not exist when the loop above ran zero times.
        for i in range(batch_size):
            if not done_sampling[i]:
                lang_length[i] += 1
            done_sampling[i] = True

        # Concatenate the tokens and blank everything past each sequence's length.
        lang_tensor = torch.cat(lang, 1)

        for i in range(lang_tensor.shape[0]):
            lang_tensor[i, lang_length[i]:] = 0

        # Trim to the longest sequence in the batch.
        max_lang_len = lang_length.max()
        lang_tensor = lang_tensor[:, :max_lang_len, :]

        if activation == 'multinomial':
            # Sum the log-probabilities of the sampled tokens per sequence.
            lang_prob_tensor = torch.cat(lang_prob, 1)
            for i in range(lang_prob_tensor.shape[0]):
                lang_prob_tensor[i, lang_length[i]:] = 0
            lang_prob_tensor = lang_prob_tensor[:, :max_lang_len]
            lang_prob = lang_prob_tensor.sum(1)

        # `eos_prob` is empty when no token was sampled; there is nothing to
        # penalise in that case, so fall through to the zero loss.
        if length_penalty and eos_prob:
            # Weight the EOS log-probability by its 1-based position, mask the
            # positions past each sequence's length, and average.
            eos_prob = torch.stack(eos_prob, dim = 1)
            for i in range(eos_prob.shape[0]):
                r_len = torch.arange(1,eos_prob.shape[1]+1,dtype=torch.float32)
                eos_prob[i] = eos_prob[i]*r_len.to(eos_prob.device)
                eos_prob[i, lang_length[i]:] = 0
            eos_loss = -eos_prob
            eos_loss = eos_loss.sum(1)/lang_length.float()
            eos_loss = eos_loss.mean()
        else:
            eos_loss = 0

        return lang_tensor, lang_length, eos_loss, lang_prob

    def to_text(self, lang_onehot):
        """Decode ``(B, L, V)`` one-hot utterances into space-joined strings.

        Each sentence is cut after its first EOS token.
        """
        texts = []
        lang = lang_onehot.argmax(2)
        for sample in lang.cpu().numpy():
            text = []
            for item in sample:
                text.append(i2w[item])
                if item == EOS:
                    break
            texts.append(' '.join(text))
        # np.str_ is the name that survives NumPy 2.0; np.unicode_ was removed.
        return np.array(texts, dtype=np.str_)

In [None]:
from literal_listener_shapeworld import ShapeWorld_RNN_L0
from vision import ConvNet

# Smoke test: build a fresh speaker and run roughly 100 training batches against
# the pretrained literal listener, printing a sample sentence every 10 batches.
emb_dim = 768
speaker_embs = nn.Embedding(len(w2i), emb_dim)
speaker_feat = ConvNet(4)
speaker = Speaker(speaker_feat, speaker_embs)
speaker.to(device)

literal_listener = ShapeWorld_RNN_L0(len(w2i)).to(device)
# Build the checkpoint path from its components: the previous backslash literal
# relied on the invalid escape "\s" and only resolved on Windows.
listener_ckpt = os.path.join("model_params", "shapeworld_rnn_full-data_100epoch_l0_last.pth")
literal_listener.load_state_dict(torch.load(listener_ckpt, map_location=device))

max_len = 4

optimizer = optim.Adam(list(speaker.parameters()),lr=0.001)
criterion = nn.CrossEntropyLoss()

losss = []
accs = []

# The modes do not change inside the loop, so set them once.
literal_listener.train()
speaker.train()
for i,(cols,label,lang) in enumerate(train_batch):
    cols, lang = cols.to(device).to(torch.float), lang.to(device)
    label = label.to(device).to(torch.float)
    optimizer.zero_grad()
    lang_tensor,lang_length,eos_loss,lang_prob = speaker(cols, label, length_penalty=False, max_len=max_len)
    # Token-level language loss: flatten (B, L, V) predictions and (B, L) reference
    # ids. reshape is used because lang_tensor may be a trimmed, non-contiguous
    # slice. The reference ids are used directly (one-hot encoding them and taking
    # the arg-max again yields the same ids).
    # NOTE(review): this requires the generated and reference lengths to match.
    lang_out = lang_tensor.reshape(-1, len(w2i))
    lang_target = lang.reshape(-1).to(torch.int64)
    # Everything already lives on `device`; no hard-coded .cuda() is needed.
    lang_loss = criterion(lang_out, lang_target)
    output_lang = lang_tensor.argmax(2)
    lis_labels = literal_listener(cols, output_lang)
    loss = criterion(lis_labels,label) + lang_loss + eos_loss*0.0001
    loss.backward()
    optimizer.step()
    if i%10 == 0:
        print("\nOriginal sentence:\n"+" ".join([i2w[idx] for idx in lang[0].to("cpu").tolist()]).replace(" <pad>",""))
        print("Generated sentence:\n"+" ".join([i2w[idx] for idx in output_lang[0].to("cpu").tolist()]).replace(" <pad>","")+"\n")
        print("Loss: ",loss.item())
        losss.append(loss.item())
        pred_labels = torch.argmax(lis_labels,dim=1)
        # Score against the batch's own labels rather than a constant index.
        acc = (pred_labels == label.argmax(dim=1)).float().mean().item()
        print("Accuracy:",acc)
        accs.append(acc)
    if i > 100: break


## Training the model

In [None]:
from literal_listener_shapeworld import ShapeWorld_RNN_L0
from vision import ConvNet

# Fresh speaker, pretrained literal listener, optimizer and loss for the full run.
emb_dim = 768
speaker_embs = nn.Embedding(len(w2i), emb_dim)
speaker_feat = ConvNet(4)
speaker = Speaker(speaker_feat, speaker_embs)
speaker.to(device)

literal_listener = ShapeWorld_RNN_L0(len(w2i)).to(device)
# Build the checkpoint path from its components: the previous backslash literal
# relied on the invalid escape "\s" and only resolved on Windows.
listener_ckpt = os.path.join("model_params", "shapeworld_rnn_full-data_100epoch_l0_last.pth")
literal_listener.load_state_dict(torch.load(listener_ckpt, map_location=device))

# Maximum utterance length, including the SOS and EOS tokens.
max_len = 4

# Only the speaker's parameters are handed to the optimizer.
optimizer = optim.Adam(speaker.parameters(),lr=0.01)
criterion = nn.CrossEntropyLoss()

epoch = 10

In [None]:
def _speaker_listener_step(cols, label):
    """Run the speaker and the literal listener on one batch.

    Args:
        cols: Image batch from the loader.
        label: Target flags from the loader, one row per sample.

    Returns:
        ``(loss, acc)``: the listener cross-entropy against ``label`` (plus the
        weighted EOS loss) as a tensor, and the fraction of samples whose listener
        arg-max matches the label arg-max as a Python float.
    """
    cols = cols.to(device).to(torch.float)
    label = label.to(device).to(torch.float)
    lang_tensor, _, eos_loss, _ = speaker(cols, label, length_penalty=False, max_len=max_len)
    # NOTE(review): argmax returns integer token ids, so no gradient appears to
    # flow from the listener loss back into the speaker along this path --
    # confirm whether ShapeWorld_RNN_L0 can consume the one-hot `lang_tensor`.
    output_lang = lang_tensor.argmax(2)
    lis_labels = literal_listener(cols, output_lang).to(device)
    loss = criterion(lis_labels, label) + eos_loss*0.0001
    # Score against the batch's own labels rather than a constant index.
    acc = (lis_labels.argmax(dim=1) == label.argmax(dim=1)).float().mean().item()
    return loss, acc


train_loss_list = []
test_loss_list = []
train_acc_list = []
test_acc_list = []
# Start from +inf so that the first epoch always produces a checkpoint.
best_loss = float("inf")
for i in range(epoch):
    print("##############################################")
    print("Epoch:{}/{}".format(i+1,epoch))
    train_loss = 0
    test_loss = 0
    train_acc = 0
    test_acc = 0

    # The listener is kept in train mode for the training pass (per the original
    # note, this is needed for the RNN loss backward). Its weights are not in the
    # optimizer, so they are never updated.
    # TODO: the pragmatic (L1) loss term was disabled; only the L0 loss is used.
    literal_listener.train()
    speaker.train()
    for cols,label,lang in train_batch:
        optimizer.zero_grad()
        loss, acc = _speaker_listener_step(cols, label)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        train_acc += acc
    batch_train_loss = train_loss/len(train_batch)
    batch_train_acc = train_acc/len(train_batch)

    # Evaluate with both models in eval mode and autograd switched off.
    literal_listener.eval()
    speaker.eval()
    with torch.no_grad():
        for cols,label,lang in test_batch:
            loss, acc = _speaker_listener_step(cols, label)
            test_loss += loss.item()
            test_acc += acc
    batch_test_loss = test_loss/len(test_batch)
    batch_test_acc = test_acc/len(test_batch)

    print("Train Loss:{:.2E}, Test Loss:{:.2E}".format(batch_train_loss,batch_test_loss))
    print("Train Acc:{:.2E}, Test Acc:{:.2E}".format(batch_train_acc,batch_test_acc))
    train_loss_list.append(batch_train_loss)
    test_loss_list.append(batch_test_loss)
    train_acc_list.append(batch_train_acc)
    test_acc_list.append(batch_test_acc)
    if batch_test_loss < best_loss:
        # Keep the weights of the epoch with the lowest test loss.
        torch.save(speaker.to(device).state_dict(),"model_params/color_S1_lis=emb-rnn-L0_original_rnn_no-penalty_L0-loss.pth")
        best_loss = batch_test_loss


In [None]:
# visualization
def _plot_curves(title, ylabel, train_values, test_values, train_name, test_name):
    """Plot train (solid blue) and test (dashed red) curves over epochs 1..epoch."""
    epochs = range(1, epoch + 1)
    plt.figure()
    plt.title(title)
    plt.xlabel("epoch")
    plt.ylabel(ylabel)
    plt.plot(epochs, train_values, "b-", label=train_name)
    plt.plot(epochs, test_values, "r--", label=test_name)
    plt.legend()
    plt.show()


_plot_curves("Train and Test Loss", "Literal_Listener_loss",
             train_loss_list, test_loss_list, "train_loss", "test_loss")
# Accuracies may have been accumulated as tensors; convert them to plain floats.
train_acc_list = list(map(float, train_acc_list))
test_acc_list = list(map(float, test_acc_list))
_plot_curves("Train and Test Accuracy", "Accuracy",
             train_acc_list, test_acc_list, "train_acc", "test_acc")

## Accuracy test

In [None]:
# Held-out evaluation split: the third file of the dataset directory.
d = load_raw_data(os.path.join(data_path, data_list[2]))
# One shape per line: images, labels, captions.
print(*(d[field].shape for field in ("imgs", "labels", "langs")), sep="\n")
check_raw_data(d["imgs"], d["labels"], d["langs"], id=1)
eval_batch = DataLoader(ShapeWorld(d, vocab), shuffle=False, batch_size=32)

In [None]:
import re
def check_data(imgs, label, g_langs, c_langs):
    """Show one sample's images with the generated and reference captions.

    Args:
        imgs: Iterable of image tensors, one per panel (three panels are drawn).
        label: Iterable of flags aligned with ``imgs``; a flag equal to ``1``
            titles its panel "Correct".
        g_langs: Generated words, joined with spaces for the title.
        c_langs: Reference words, joined with spaces for the title.
    """
    fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(6, 4))
    caption = "Generated: " + " ".join(g_langs) + "\n\nCorrect: " + " ".join(c_langs)
    # Strip the sentence-boundary markers from the displayed text.
    fig.suptitle(re.sub(r"<sos>|<eos>", "", caption))
    for panel, (flag, picture) in enumerate(zip(label, imgs)):
        ax = axes[panel]
        # Swap the first and last axes before plotting.
        ax.imshow(picture.transpose(2, 0))
        if flag == 1:
            ax.set_title("Correct")
    plt.show()

### L0 Accuracy

In [None]:
#speaker_embs = nn.Embedding(len(w2i), emb_dim)
#speaker_feat = Imgs_Feature(output_size=1024)
#speaker_feat = ConvNet(4)
#speaker = Speaker(speaker_feat, speaker_embs)
#speaker.load_state_dict(torch.load("model_params\color_S1_lis=emb-rnn-L0_original_rnn_no-penalty+cpu-lis.pth",map_location=device))
#speaker.to(device)

# Literal-listener (L0) accuracy of the speaker's utterances on the eval split.
losss = []
accs = []
# The modes do not change inside the loop; evaluation needs no gradients.
literal_listener.eval()
speaker.eval()
with torch.no_grad():
    for i,(cols,label,lang) in enumerate(eval_batch):
        cols, lang = cols.to(device).to(torch.float), lang.to(device)
        label = label.to(device).to(torch.float)
        lang_tensor,lang_length,eos_loss,lang_prob = speaker(cols, label, length_penalty=False, max_len=max_len)
        output_lang = lang_tensor.argmax(2)
        lis_labels = literal_listener(cols, output_lang)
        loss = criterion(lis_labels,label) + eos_loss*0.0001
        losss.append(loss.item())
        # Measure every batch, and score against the batch's own labels rather
        # than a constant index.
        pred_labels = torch.argmax(lis_labels,dim=1)
        accs.append((pred_labels == label.argmax(dim=1)).float().mean().item())
        if i%10 == 0:
            # Visualise the first sample of every 10th batch.
            imgs = cols[0].to("cpu")
            c_langs = [i2w[idx] for idx in lang[0].to("cpu").tolist()]
            g_langs = [i2w[idx] for idx in output_lang[0].to("cpu").tolist()]
            check_data(imgs, label[0], g_langs, c_langs)
        if i > 100: break

print("Accuracy:,",np.mean(accs))

### L1 Accuracy

In [None]:
def _utterance_logprob(word_dists, token_ids):
    """Sum ``log(p + 0.001)`` of the given tokens under per-step distributions.

    Args:
        word_dists: ``(B, L1, V)`` per-step token distributions.
        token_ids: ``(B, L2)`` int64 token ids to score.

    Returns:
        ``(B,)`` CPU tensor, detached. Positions beyond the shorter of the two
        lengths are ignored.
    """
    n_pos = min(word_dists.shape[1], token_ids.shape[1])
    picked = word_dists[:, :n_pos].gather(2, token_ids[:, :n_pos].unsqueeze(2)).squeeze(2)
    return torch.log(picked + 0.001).sum(1).detach().cpu()


# Pragmatic (L1) check: generate an utterance for each candidate target, then ask
# under which target the utterance generated for image index 2 scores highest.
losss = []
accs = []
literal_listener.eval()
speaker.eval()
with torch.no_grad():
    for i,(cols,label,lang) in enumerate(eval_batch):
        cols, lang = cols.to(device).to(torch.float), lang.to(device)
        label = label.to(device).to(torch.float)
        lang_tensor,lang_length,eos_loss,lang_prob = speaker(cols, label, length_penalty=False, max_len=max_len)
        output_lang1 = lang_tensor.argmax(2)
        # for 2nd image
        label02 = torch.zeros_like(label)
        label02[:,1] = 1.0
        lang_tensor1,lang_length,eos_loss,lang_prob = speaker(cols, label02, length_penalty=False, max_len=max_len)
        output_lang2 = lang_tensor1.argmax(2)
        # for 3rd image
        label03 = torch.zeros_like(label)
        label03[:,2] = 1.0
        lang_tensor2,lang_length,eos_loss,lang_prob = speaker(cols, label03, length_penalty=False, max_len=max_len)
        output_lang3 = lang_tensor2.argmax(2)

        # Score the utterance generated for the 3rd image under each of the three
        # speaker outputs; columns are ordered (given label, 2nd image, 3rd image).
        probs = torch.stack([
            _utterance_logprob(lang_tensor, output_lang3),
            _utterance_logprob(lang_tensor1, output_lang3),
            _utterance_logprob(lang_tensor2, output_lang3),
        ], dim=1)
        # NOTE(review): the loss is taken against `label` although the scored
        # utterance was generated for image index 2 -- confirm this is intended.
        loss = criterion(probs.to(device),label) + eos_loss*0.0001
        losss.append(loss.item())
        pred_labels = torch.argmax(probs,dim=1)
        # The scored utterance was generated with image index 2 flagged as the
        # target, so index 2 is the expected answer.
        correct_labels = torch.zeros(cols.shape[0])+2
        acc = sum(correct_labels==pred_labels)/len(correct_labels)
        accs.append(acc.item())

        if i%10 == 0:
            imgs = cols[0].to("cpu")
            c_langs = [i2w[idx] for idx in lang[0].to("cpu").tolist()]
            # Show this cell's own first utterance (output_lang1); the earlier code
            # read a stale `output_lang` left over from the previous cell.
            g_langs = [i2w[idx] for idx in output_lang1[0].to("cpu").tolist()] \
                    + ["|"]+ [i2w[idx] for idx in output_lang2[0].to("cpu").tolist()] \
                    + ["|"]+ [i2w[idx] for idx in output_lang3[0].to("cpu").tolist()]
            check_data(imgs, label[0], g_langs, c_langs)
            print(torch.exp(probs[0]))
            print(torch.where(torch.exp(probs[0])>0.1,1,0))
        if i > 100: break

print("Accuracy:,",np.mean(accs))