In [None]:
from pathlib import Path
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from shapeworld_data import load_raw_data, get_vocab, ShapeWorld

In [None]:
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

## Prepare Data

In [None]:
def check_raw_data(imgs, labels, langs, id=0):
    """Plot the three candidate images of one raw sample, titled with its caption.

    Args:
        imgs: per-sample sequences of images; each image is shown after
            ``transpose((2, 1, 0))`` (assumes channel-first numpy arrays — TODO confirm).
        labels: per-sample sequences of flags; the image whose flag equals 1 is titled "Correct".
        langs: per-sample caption token sequences, joined with spaces for the figure title.
        id: index of the sample to display.
    """
    sample_imgs, sample_flags, sample_lang = imgs[id], labels[id], langs[id]

    fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(6, 2))
    fig.suptitle(" ".join(sample_lang))
    for pos, (flag, image) in enumerate(zip(sample_flags, sample_imgs)):
        panel = axes[pos]
        panel.imshow(image.transpose((2, 1, 0)))
        if flag == 1:
            panel.set_title("Correct")
    plt.show()

In [None]:
# The notebook sits two directories below the project root; the numpy splits live in
# <root>/data/shapeworld_np. Path components are joined separately: the previous
# "data\shapeworld_np" literal only worked on Windows and contains an invalid "\s" escape.
root = Path(os.path.abspath('')).parent.parent.absolute()
data_path = os.path.join(root, "data", "shapeworld_np")
print(data_path)
# os.listdir returns entries in arbitrary order; sort so data_list[0..3] (train) and
# data_list[-1] (test) refer to the same files on every machine.
data_list = sorted(os.listdir(data_path))
print(data_list)

### Generating vocab_dict 

In [None]:
# Build the word <-> index vocabulary over every split file found in data_path.
vocab = get_vocab([os.path.join(data_path, fname) for fname in data_list])
print(vocab["w2i"])

# One-hot attribute codes; "other" / "shape" encode "unspecified" as all zeros.
COLOR = {
    "white": [1, 0, 0, 0, 0, 0],
    "green": [0, 1, 0, 0, 0, 0],
    "gray": [0, 0, 1, 0, 0, 0],
    "yellow": [0, 0, 0, 1, 0, 0],
    "red": [0, 0, 0, 0, 1, 0],
    "blue": [0, 0, 0, 0, 0, 1],
    "other": [0, 0, 0, 0, 0, 0],
}
SHAPE = {
    "shape": [0, 0, 0, 0],
    "square": [1, 0, 0, 0],
    "circle": [0, 1, 0, 0],
    "rectangle": [0, 0, 1, 0],
    "ellipse": [0, 0, 0, 1],
}

w2i, i2w = vocab["w2i"], vocab["i2w"]

# Special-token indices; assumed to match the ordering produced by get_vocab — TODO confirm.
PAD, SOS, EOS, UNK = 0, 1, 2, 3

### Prepare the data loaders

In [None]:
# Training split: the first four split files (data_list[0..3]); data_list[-1] is held out for testing.
N_TRAIN_FILES = 4
train_parts = []
for file_idx in range(N_TRAIN_FILES):
    part = load_raw_data(os.path.join(data_path, data_list[file_idx]))
    check_raw_data(part["imgs"], part["labels"], part["langs"])
    train_parts.append(part)

# Concatenate once instead of re-stacking inside the loop (repeated vstack copies quadratically).
imgs = np.vstack([p["imgs"] for p in train_parts])
labels = np.vstack([p["labels"] for p in train_parts])
langs = np.hstack([p["langs"] for p in train_parts])

# Reuse the last loaded dict so any other keys it carries still reach ShapeWorld.
d = train_parts[-1]
d["imgs"], d["labels"], d["langs"] = imgs, labels, langs
del train_parts
print(d["imgs"].shape, d["labels"].shape, d["langs"].shape)

train_dataset = ShapeWorld(d, vocab)
print(len(train_dataset))
# Shuffle the training batches (seeded for reproducibility); without it every epoch
# walks the four files in the same fixed order.
train_batch = DataLoader(train_dataset, batch_size=32, shuffle=True,
                         generator=torch.Generator().manual_seed(0))

In [None]:
# Held-out test split: the last entry of data_list.
d = load_raw_data(os.path.join(data_path, data_list[-1]))
for key in ("imgs", "labels", "langs"):
    print(d[key].shape)
check_raw_data(d["imgs"], d["labels"], d["langs"], id=1)
test_batch = DataLoader(ShapeWorld(d, vocab), batch_size=32, shuffle=False)

## Model

In [None]:
def to_onehot(y, n):
    """Return a float32 one-hot matrix of shape (y.shape[0], n) on y's device.

    Args:
        y: integer index tensor; it is flattened with ``view(-1, 1)``, so it must
            hold exactly y.shape[0] elements.
        n: number of classes (columns).
    """
    index_column = y.view(-1, 1)
    encoded = torch.zeros(y.shape[0], n).to(y.device)
    return encoded.scatter_(1, index_column, 1)

In [None]:
class Imgs_emb_DeepSet(nn.Module):
    """Merge two distractor embeddings into one context vector.

    Each input passes through its own linear layer followed by ReLU; the two
    results are summed and projected by a final linear layer. The branches do
    not share weights, so the output is not symmetric in its two arguments.

    Args:
        input_size: width of each input embedding.
        output_size: width of the hidden sum and of the returned vector.
    """

    def __init__(self, input_size=10, output_size=20):
        super().__init__()
        # Creation order (linear1, linear2, linear3) is kept so initialisation and state_dict keys match.
        self.linear1 = nn.Linear(input_size, output_size)
        self.linear2 = nn.Linear(input_size, output_size)
        self.linear3 = nn.Linear(output_size, output_size)

    def forward(self, img_emb1, img_emb2):
        branch_a = F.relu(self.linear1(img_emb1))
        branch_b = F.relu(self.linear2(img_emb2))
        return self.linear3(branch_a + branch_b)

In [None]:
from literal_listener_shapeworld import CNN_encoder

class CS_CNN_encoder_Feature(nn.Module):
    """Target-vs-context image features built on frozen colour and shape CNNs.

    Each image is embedded by two pretrained ``CNN_encoder`` heads (6 colour
    outputs, 4 shape outputs) loaded from ``model_params/``; their outputs are
    concatenated into a 10-d embedding. The target image's embedding is joined
    with an ``Imgs_emb_DeepSet`` summary of the two other images and projected
    by a linear layer.

    Args:
        input_size: width of the target embedding entering ``self.linear``.
        output_size: width of the returned feature vector.
    """

    def __init__(self, input_size=10, output_size=10):
        super(CS_CNN_encoder_Feature, self).__init__()
        self.cnn_color_encoder = CNN_encoder(6)
        self.cnn_color_encoder.load_state_dict(torch.load("model_params/shapeworld_original-cnn_color_model.pth",map_location=device))
        for params in self.cnn_color_encoder.parameters(): params.requires_grad = False
        self.cnn_shape_encoder = CNN_encoder(4)
        self.cnn_shape_encoder.load_state_dict(torch.load("model_params/shapeworld_original-cnn_shape_model.pth",map_location=device))
        for params in self.cnn_shape_encoder.parameters(): params.requires_grad = False
        self.deepset_size = 6+4
        self.deepset = Imgs_emb_DeepSet(self.deepset_size, self.deepset_size)
        self.linear = nn.Linear(input_size+self.deepset_size, output_size)

    def train(self, mode=True):
        """Switch train/eval mode but keep the frozen pretrained encoders in eval mode.

        Their weights are never updated. If CNN_encoder contains dropout or
        batch-norm layers (not visible here), train mode would still change
        their behaviour and running statistics.
        """
        super().train(mode)
        self.cnn_color_encoder.eval()
        self.cnn_shape_encoder.eval()
        return self

    def get_feat_emb(self,feat):
        """Concatenate the colour and shape encoder outputs for a batch of images."""
        col_embs = self.cnn_color_encoder(feat)
        shape_embs = self.cnn_shape_encoder(feat)
        img_embs = torch.hstack((col_embs,shape_embs))
        return img_embs

    def forward(self,feats,labels):
        """Embed each row's target image together with its two other candidates.

        Args:
            feats: batch of image triples indexed as ``feats[b, k]`` for candidate k in {0, 1, 2}.
            labels: (batch, 3) one-hot tensor; the max of each row marks that row's target.

        Returns:
            (batch, output_size) feature tensor.
        """
        batch_size = feats.size(0)
        rows = torch.arange(batch_size, device=feats.device)
        # Target index per row. Taking argmax over the whole (flattened) labels tensor
        # only worked when every row shared the same target.
        target_idx = labels.argmax(dim=1).to(feats.device)
        candidates = torch.arange(3, device=feats.device).expand(batch_size, 3)
        # The two remaining candidates of each row, in ascending index order.
        other_idx = candidates[candidates != target_idx.unsqueeze(1)].view(batch_size, 2)
        target_img = feats[rows, target_idx]
        other_img1 = feats[rows, other_idx[:, 0]]
        other_img2 = feats[rows, other_idx[:, 1]]
        target_embs = self.get_feat_emb(target_img)
        other_embs1 = self.get_feat_emb(other_img1)
        other_embs2 = self.get_feat_emb(other_img2)                 # (batch_size,10)
        other_embs = F.relu(self.deepset(other_embs1,other_embs2))  # (batch_size,10)
        embs = torch.hstack((target_embs,other_embs))               # (batch_size,20)
        feat = self.linear(embs)
        return feat

In [None]:
class Speaker(nn.Module):
    """GRU speaker that produces caption distributions for a target image.

    ``feat_model(feats, labels)`` returns a (batch, feat_size) feature that is
    mapped through two linear layers (ReLU in between) to the initial GRU hidden
    state. Tokens are embedded with ``embedding_module``, whose
    ``num_embeddings`` fixes the output vocabulary size.

    Args:
        feat_model: module called as ``feat_model(feats, labels)``.
        embedding_module: ``nn.Embedding`` used both for token inputs and for soft decoding.
        feat_size: width of the vector returned by ``feat_model``.
        hidden_size: GRU hidden size.
    """

    def __init__(self, feat_model, embedding_module, feat_size=1024, hidden_size=100):
        super(Speaker, self).__init__()
        self.embedding = embedding_module
        self.feat_model = feat_model
        self.embedding_dim = embedding_module.embedding_dim
        self.vocab_size = embedding_module.num_embeddings
        self.hidden_size = hidden_size
        self.gru = nn.GRU(self.embedding_dim, self.hidden_size, bidirectional=False)
        # Unidirectional GRU, so outputs are hidden_size wide (*2 if it were bidirectional).
        self.outputs2vocab = nn.Linear(self.hidden_size*1, self.vocab_size)
        self.init_h1 = nn.Linear(feat_size, self.hidden_size)
        self.init_h2 = nn.Linear(self.hidden_size, self.hidden_size)

    def _initial_state(self, feats, labels):
        """Map image features to the (1, batch, hidden_size) initial GRU state."""
        feats_emb = self.feat_model(feats, labels)
        return self.init_h2(F.relu(self.init_h1(feats_emb))).unsqueeze(0)

    def forward(self,feats,labels,lang,x_lens):
        """Teacher-forced decoding of ``lang``.

        Args:
            feats, labels: forwarded to ``feat_model``.
            lang: (batch, seq_len) token ids fed as GRU inputs.
            x_lens: per-row valid lengths (moved to CPU for packing).

        Returns:
            (batch, max(x_lens), vocab_size) softmax probabilities — not logits.
        """
        states = self._initial_state(feats, labels)
        embedded = self.embedding(lang).transpose(0, 1)                   # (B,L,D) to (L,B,D)
        packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, x_lens.to("cpu"), enforce_sorted=False)
        packed_outputs, _ = self.gru(packed_embedded, states)
        outputs, _ = nn.utils.rnn.pad_packed_sequence(packed_outputs)    # (L,B,H)
        probs = F.softmax(self.outputs2vocab(outputs), dim=-1)
        return probs.transpose(0, 1)

    def generate(self, feats, labels, tau=1, max_len=40):
        """Soft (differentiable) decoding of ``max_len`` steps.

        At each step, the expected embedding under the predicted distribution is
        fed back as the next input. No sampling takes place. ``tau`` is accepted
        for compatibility but is not used.

        Returns:
            (batch, max_len, vocab_size): a one-hot SOS, ``max_len - 2`` predicted
            distributions, then a one-hot EOS.
        """
        batch_size = feats.size(0)
        states = self._initial_state(feats, labels)

        sos_onehot = torch.zeros(batch_size, 1, self.vocab_size, device=feats.device)
        sos_onehot[:, 0, SOS] = 1.0
        lang = [sos_onehot]

        # (B,1,V) -> (1,B,V) @ (V,D) -> (1,B,D)
        inputs = sos_onehot.transpose(0, 1) @ self.embedding.weight
        for _ in range(max_len - 2):  # leave room for SOS and EOS
            self.gru.flatten_parameters()
            outputs, states = self.gru(inputs, states)                 # (1,B,H)
            # Squeeze only the time axis: a bare squeeze() also removed the batch
            # axis when batch_size == 1 and broke the concatenation below.
            logits = self.outputs2vocab(outputs.squeeze(0))           # (B,V)
            predicted = F.softmax(logits, dim=-1)
            lang.append(predicted.unsqueeze(1))
            inputs = predicted.unsqueeze(0) @ self.embedding.weight    # (1,B,V) @ (V,D) -> (1,B,D)

        # Always terminate with EOS.
        eos_onehot = torch.zeros(batch_size, 1, self.vocab_size, device=feats.device)
        eos_onehot[:, 0, EOS] = 1.0
        lang.append(eos_onehot)
        return torch.cat(lang, 1)                                      # (B,max_len,V)

## Training the model

In [None]:
def check_batch(cols,langs,label,output,gen_lang=None,id=0):
    """Plot one sample of a batch with its reference and predicted captions.

    Args:
        cols: batch of image triples; each image tensor is shown after ``transpose(2, 0)``.
        langs: reference caption token ids per sample.
        label: per-sample flags; the image whose flag equals 1 is titled "Correct".
        output: teacher-forced predicted token ids per sample.
        gen_lang: optional generated token ids per sample. When given, SOS/EOS/PAD
            are removed from all captions; otherwise captions are shown raw.
        id: index of the sample to plot.
    """
    def _decode(ids, drop_special):
        # Map token ids to words, optionally dropping SOS/EOS/PAD.
        words = []
        for idx in ids:
            tok = int(idx)
            if drop_special and tok in (SOS, EOS, PAD):
                continue
            words.append(i2w[tok])
        return " ".join(words)

    drop_special = gen_lang is not None
    img_list, sample_label = cols[id], label[id]
    # Fixes the "Corretc" typo in the displayed caption.
    text = "Correct: " + _decode(langs[id], drop_special) + "\nOutput: " + _decode(output[id], drop_special)
    if drop_special:
        text += "\nGenerated: " + _decode(gen_lang[id], drop_special)

    fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(6, 4))
    fig.suptitle(text)
    for i,(l,img) in enumerate(zip(sample_label,img_list)):
        img = img.transpose(2,0).to("cpu").detach().numpy()
        axes[i].imshow(img)
        if l==1: axes[i].set_title("Correct")
    plt.show()

### Setting

In [None]:
from literal_listener_shapeworld import ShapeWorld_RNN_L0

emb_dim = 768
feat_dim = 10
speaker_embs = nn.Embedding(len(w2i), emb_dim)
speaker_feat = CS_CNN_encoder_Feature(output_size=feat_dim)
speaker = Speaker(speaker_feat, speaker_embs, feat_size=feat_dim)
speaker.to(device)

# Pretrained literal listener L0: a frozen evaluator of the speaker's captions.
# Forward slash: portable, and the old "model_params\shapeworld..." literal held an invalid "\s" escape.
literal_listener = ShapeWorld_RNN_L0(len(w2i)).to(device)
literal_listener.load_state_dict(torch.load("model_params/shapeworld_rnn_full-data_100epoch_l0_last.pth",map_location=device))
literal_listener.eval()
for lis_param in literal_listener.parameters():
    lis_param.requires_grad_(False)

max_len = 4  # generated caption length including SOS and EOS

# Only optimise trainable speaker weights (the pretrained CNN encoders are frozen).
optimizer = optim.Adam((p for p in speaker.parameters() if p.requires_grad), lr=0.001)
criterion = nn.CrossEntropyLoss()

epoch = 10

loss_pth = "model_params/shapeworld_S0_lis=emb-rnn_CS-CNN_encoder_loss=Lang_full-data_best-loss.pth"
acc_pth = "model_params/shapeworld_S0_lis=emb-rnn_CS-CNN_encoder_loss=Lang_full-data_best-acc.pth"

In [None]:
from tqdm import tqdm

train_loss_list = []
test_loss_list = []
train_acc_list = []
test_acc_list = []
best_loss = float("inf")  # any finite test loss improves on this (was a magic 100)
best_acc = 0


def _teacher_forced_loss(cols, label, lang):
    """Return (loss, probs) for teacher-forced decoding of ``lang`` given the images.

    ``speaker`` returns softmax probabilities, not logits. Feeding them directly
    to ``nn.CrossEntropyLoss`` applied a second softmax and flattened the
    gradients. Log-probabilities are passed instead: log_softmax(log p) == log p
    when p sums to 1, so ``criterion`` computes the true cross-entropy.
    """
    vocab_size = len(w2i)
    # Every caption is fed with max_len - 1 tokens (lang[:, :-1] is truncated by packing).
    x_lens = torch.full((len(cols),), max_len - 1, dtype=torch.long)
    probs = speaker(cols, label, lang[:, :-1], x_lens=x_lens)
    n_steps = probs.size(1)
    # Targets are the next tokens: positions 1..n_steps of the reference caption.
    target = F.one_hot(lang.long(), vocab_size).float()[:, 1:n_steps + 1, :]
    log_probs = torch.log(probs.clamp_min(1e-12))  # clamp avoids log(0) -> -inf/NaN
    loss = criterion(log_probs.reshape(-1, vocab_size), target.reshape(-1, vocab_size))
    return loss, probs


def _listener_accuracy(cols, label, lang_ids):
    """Fraction of rows where the frozen listener picks the labelled target image."""
    with torch.no_grad():  # the listener is never trained; don't build a graph
        lis_scores = literal_listener(cols, lang_ids)
    pred_labels = lis_scores.argmax(dim=1)
    return (pred_labels == label.argmax(dim=1)).float().mean().item()


for i in range(epoch):
    print("##############################################")
    print("Epoch:{}/{}".format(i+1,epoch))
    train_loss = 0.0
    test_loss = 0.0
    train_acc = 0.0
    test_acc = 0.0

    # The listener is a frozen evaluator: eval mode keeps dropout / batch-norm fixed.
    literal_listener.eval()
    speaker.train()
    for cols,label,lang in tqdm(train_batch):
        cols, lang, label = cols.to(device).to(torch.float), lang.to(device), label.to(device).to(torch.float)
        optimizer.zero_grad()
        loss, lang_tensor = _teacher_forced_loss(cols, label, lang)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        # Train accuracy is measured on the teacher-forced argmax captions.
        output_lang = lang_tensor.argmax(2)
        train_acc += _listener_accuracy(cols, label, output_lang)
    batch_train_loss = train_loss/len(train_batch)
    batch_train_acc = train_acc/len(train_batch)

    speaker.eval()
    with torch.no_grad():
        for cols,label,lang in tqdm(test_batch):
            cols, lang, label = cols.to(device).to(torch.float), lang.to(device), label.to(device).to(torch.float)
            loss, lang_tensor = _teacher_forced_loss(cols, label, lang)
            gen_lang_tensor = speaker.generate(cols, label, max_len=max_len)
            test_loss += loss.item()
            output_lang = lang_tensor.argmax(2)
            gen_output_lang = gen_lang_tensor.argmax(2)
            # Test accuracy is measured on freely generated captions.
            test_acc += _listener_accuracy(cols, label, gen_output_lang)
    batch_test_loss = test_loss/len(test_batch)
    batch_test_acc = test_acc/len(test_batch)

    print("Train Loss:{:.2E}, Test Loss:{:.2E}".format(batch_train_loss,batch_test_loss))
    print("Train Acc:{:.2E}, Test Acc:{:.2E}".format(batch_train_acc,batch_test_acc))
    train_loss_list.append(batch_train_loss)
    test_loss_list.append(batch_test_loss)
    train_acc_list.append(batch_train_acc)
    test_acc_list.append(batch_test_acc)
    if i%2==0: check_batch(cols,lang,label,output_lang,gen_output_lang,id=np.random.randint(len(cols)))
    if batch_test_loss < best_loss:
        print("Best loss saved ...")
        torch.save(speaker.state_dict(),loss_pth)
        best_loss = batch_test_loss
    if batch_test_acc > best_acc:
        print("Best acc saved ...")
        torch.save(speaker.state_dict(),acc_pth)
        best_acc = batch_test_acc
    


## Accuracy test

In [None]:
from literal_listener_shapeworld import ShapeWorld_RNN_L0

emb_dim = 768
feat_dim = 10
speaker_embs = nn.Embedding(len(w2i), emb_dim)
speaker_feat = CS_CNN_encoder_Feature(output_size=feat_dim)
speaker = Speaker(speaker_feat, speaker_embs, feat_size=feat_dim)
speaker.load_state_dict(torch.load(acc_pth,map_location=device))
speaker.to(device)
speaker.eval()  # evaluation only

# Forward slash: portable, and the old "model_params\shapeworld..." literal held an invalid "\s" escape.
literal_listener = ShapeWorld_RNN_L0(len(w2i)).to(device)
literal_listener.load_state_dict(torch.load("model_params/shapeworld_rnn_full-data_100epoch_l0_last.pth",map_location=device))
literal_listener.eval()

max_len = 4  # generated caption length including SOS and EOS

criterion = nn.CrossEntropyLoss()

In [None]:
import re
def check_data(imgs, label, g_langs, c_langs, text="Correct"):
    """Show three candidate images titled with generated and reference captions.

    Args:
        imgs: sequence of image tensors; each is displayed after ``transpose(2, 0)``.
        label: per-image flags; any image whose flag equals 1 gets ``text`` as its title.
        g_langs: generated caption tokens (``<sos>``/``<eos>`` are stripped from the title).
        c_langs: reference caption tokens.
        text: title placed over the flagged image(s).
    """
    caption = "Generated: " + " ".join(g_langs) + "\n\nCorrect: " + " ".join(c_langs)
    fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(6, 4))
    fig.suptitle(re.sub(r"<sos>|<eos>", "", caption))
    for pos, (flag, img) in enumerate(zip(label, imgs)):
        panel = axes[pos]
        panel.imshow(img.transpose(2, 0))
        if flag == 1:
            panel.set_title(text)
    plt.show()

### L0 accuracy

In [None]:
# L0 accuracy: caption each test sample with the speaker, then check whether the
# frozen literal listener picks the labelled target from the generated caption.
literal_listener.eval()
speaker.eval()
losss = []
accs = []
with torch.no_grad():  # evaluation only: don't build autograd graphs
    for i,(cols,label,lang) in enumerate(test_batch):
        cols, lang = cols.to(device).to(torch.float), lang.to(device)
        label = label.to(device).to(torch.float)
        lang_tensor = speaker.generate(cols, label, max_len=max_len)
        output_lang = lang_tensor.argmax(2)
        lis_labels = literal_listener(cols, output_lang)
        loss = criterion(lis_labels,label)
        losss.append(loss.item())
        pred_labels = torch.argmax(lis_labels,dim=1)
        # Compare with each row's labelled target (previously assumed to be index 0).
        target_labels = label.argmax(dim=1)
        accs.append((pred_labels == target_labels).float().mean().item())
        if i%10 == 0:
            imgs = cols[0].to("cpu")
            c_langs = [i2w[idx] for idx in lang[0].to("cpu").tolist()]
            g_langs = [i2w[idx] for idx in output_lang[0].to("cpu").tolist()]
            # check_data titles the image whose flag == 1; raw listener scores are
            # almost never exactly 1, so pass the listener's choice as a one-hot.
            pred_onehot = F.one_hot(pred_labels[0], num_classes=lis_labels.size(1)).to("cpu")
            check_data(imgs, pred_onehot, g_langs, c_langs, text="L0 prediction")

print("Accuracy:,",np.mean(accs))


### L1 accuracy

In [None]:
def get_prob_labels(lang_probs, rng=None):
    """Pick the index of the highest score in each row, breaking exact ties at random.

    Args:
        lang_probs: iterable of 1-D score rows (tensor rows, arrays or lists), e.g.
            a (batch, n_candidates) tensor of summed log-probabilities.
        rng: optional object with a ``choice`` method (such as ``np.random.Generator``)
            used for tie-breaking; defaults to the global ``np.random`` module.

    Returns:
        np.ndarray with one chosen index per row.
    """
    rng = np.random if rng is None else rng
    lang_pred = []
    for probs in lang_probs:
        scores = np.asarray(probs, dtype=float)
        # Every index that reaches the maximum. The old 3-way tie table randomised
        # between 0 and 2 whenever probs[0] == probs[2], even if probs[1] was the
        # strict maximum; choosing among the true maxima fixes that for any width.
        best = np.flatnonzero(scores == scores.max())
        if best.size == 1:
            lang_pred.append(int(best[0]))
        elif best.size > 1:
            lang_pred.append(int(rng.choice(best)))
        else:
            # e.g. NaN scores: nothing compares equal to the max.
            lang_pred.append(int(np.argmax(scores)))
    return np.array(lang_pred)

In [None]:
import re
def check_data(imgs, label, g_langs, c_langs, text="Correct"):
    """Show three candidate images; the one at index ``label`` gets ``text`` as title.

    This redefines the earlier ``check_data`` in the notebook: here ``label`` is
    a single chosen index rather than per-image flags.

    Args:
        imgs: sequence of image tensors; each is displayed after ``transpose(2, 0)``.
        label: index of the image to title.
        g_langs: generated caption tokens (``<sos>``/``<eos>`` are stripped from the title).
        c_langs: reference caption tokens.
        text: title placed over the chosen image.
    """
    caption = "Generated: " + " ".join(g_langs) + "\n\nCorrect: " + " ".join(c_langs)
    fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(6, 4))
    fig.suptitle(re.sub(r"<sos>|<eos>", "", caption))
    for pos, img in enumerate(imgs):
        panel = axes[pos]
        panel.imshow(img.transpose(2, 0))
        if label == pos:
            panel.set_title(text)
    plt.show()

In [None]:
# L1 (pragmatic) accuracy: for each candidate image k, let the speaker describe k and
# score the reference caption under that description; pick the candidate whose
# speaker assigns the reference the highest summed log-probability.
speaker.eval()
accs = []
with torch.no_grad():  # evaluation only: don't build autograd graphs
    for i,(cols,label,lang) in enumerate(test_batch):
        cols, lang, label = cols.to(device).to(torch.float), lang.to(device), label.to(device).to(torch.float)
        gen_langs = []        # argmax caption generated for each candidate
        log_prob_sums = []    # (batch,) score of the reference caption per candidate
        for k in range(label.size(1)):
            label_k = torch.zeros_like(label)
            label_k[:, k] = 1.0
            lang_tensor_k = speaker.generate(cols, label_k, max_len=max_len)
            gen_langs.append(lang_tensor_k.argmax(2))
            # Probability of each reference token at the aligned position (truncated to the
            # shorter length, as zip() did before), smoothed with +0.001 and summed in log space.
            n_pos = min(lang_tensor_k.size(1), lang.size(1))
            token_probs = lang_tensor_k[:, :n_pos].gather(2, lang[:, :n_pos].long().unsqueeze(-1)).squeeze(-1)
            log_prob_sums.append(torch.log(token_probs + 0.001).sum(dim=1).cpu())
        probs = torch.stack(log_prob_sums, dim=1)                    # (batch, n_candidates)
        pred_labels = get_prob_labels(probs)
        # Compare with each row's labelled target (previously assumed to be index 0).
        target_labels = label.argmax(dim=1).cpu().numpy()
        accs.append(float(np.mean(pred_labels == target_labels)))

        if i%10 == 0:
            imgs = cols[0].to("cpu")
            c_langs = [i2w[idx] for idx in lang[0].to("cpu").tolist()]
            g_langs = []
            for k, gen in enumerate(gen_langs):
                if k:
                    g_langs.append("|")
                g_langs += [i2w[idx] for idx in gen[0].to("cpu").tolist()]
            check_data(imgs, pred_labels[0], g_langs, c_langs, text="L1 prediction")
            print(torch.exp(probs[0]))
            print(torch.where(torch.exp(probs[0])>0.1,1,0))
        if i > 100: break  # only the first 102 batches are evaluated

print("Accuracy:,",np.mean(accs))