In [None]:
from pathlib import Path
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from shapeworld_data import load_raw_data, get_vocab, ShapeWorld


In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

## Prepare Data

In [None]:
def check_raw_data(imgs, labels, langs, id=0):
    """Show one raw sample: its three images in a row, titled with its utterance.

    Args:
        imgs: indexable collection; ``imgs[id]`` is iterated to get the images.
        labels: indexable collection; ``labels[id]`` holds one flag per image.
        langs: indexable collection; ``langs[id]`` is a sequence of word strings.
        id: position of the sample to display.

    Images whose label equals 1 get the subplot title "Correct".
    """
    # Index the requested sample directly instead of zipping the whole dataset
    # into a list just to pick a single element from it.
    img_list, label, lang = imgs[id], labels[id], langs[id]
    fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(6, 2))
    fig.suptitle(" ".join(lang))
    for i, (l, img) in enumerate(zip(label, img_list)):
        # Reverse the axis order before plotting
        # (presumably channel-first -> channel-last for imshow; TODO confirm).
        axes[i].imshow(img.transpose((2, 1, 0)))
        if l == 1:
            axes[i].set_title("Correct")
    plt.show()

In [None]:
# The project root is taken to be two directories above the current working directory.
root = Path(os.path.abspath('')).parent.parent.absolute()
# Join the components separately: a literal "data\shapeworld_np" contains the
# invalid escape sequence "\s" and only forms a valid path on Windows.
data_path = os.path.join(root, "data", "shapeworld_np")
print(data_path)
# os.listdir returns entries in arbitrary order; sort them so that the
# positional train/test selection made later (data_list[0..3], data_list[-1])
# is reproducible across machines and runs.
data_list = sorted(os.listdir(data_path))
print(data_list)

### Generating vocab_dict 

In [None]:
# Build the vocabulary from every file found in the data directory.
vocab = get_vocab([os.path.join(data_path, fname) for fname in data_list])
print(vocab["w2i"])

# One-hot colour codes (6 dims); "other" is the all-zero "no colour given" code.
# Key order matters: positions in this dict are used to decode argmax indices.
COLOR = {
    "white":  [1, 0, 0, 0, 0, 0],
    "green":  [0, 1, 0, 0, 0, 0],
    "gray":   [0, 0, 1, 0, 0, 0],
    "yellow": [0, 0, 0, 1, 0, 0],
    "red":    [0, 0, 0, 0, 1, 0],
    "blue":   [0, 0, 0, 0, 0, 1],
    "other":  [0, 0, 0, 0, 0, 0],
}
# Shape codes (4 dims); the generic word "shape" is the all-zero code.
SHAPE = {
    "shape":     [0, 0, 0, 0],
    "square":    [1, 0, 0, 0],
    "circle":    [0, 1, 0, 0],
    "rectangle": [0, 0, 1, 0],
    "ellipse":   [0, 0, 0, 1],
}

# Word <-> index lookup tables taken from the vocabulary.
w2i, i2w = vocab["w2i"], vocab["i2w"]

# Indices of the special tokens.
PAD, SOS, EOS, UNK = 0, 1, 2, 3

In [None]:
import re
def utter2tensor(utter):
    """Encode an utterance as a 10-dim vector: colour code (6) followed by shape code (4).

    Special tokens (<sos>, <eos>, <PAD>, <UNK>) are removed first.

    * one word that is a key of COLOR  -> that colour + the generic shape code
    * one word that is a key of SHAPE  -> the "other" colour + that shape
    * two words                        -> COLOR[first] + SHAPE[second]
    * anything else                    -> the "other" colour + the generic shape code

    Raises:
        KeyError: for a two-word utterance whose words are not a known
            colour followed by a known shape.
    """
    # split() with no argument discards the empty strings that removing the
    # special tokens leaves behind.  With split(" "), "<sos> red square <eos>"
    # became ["", "red", "square", ""] and silently fell through to the default.
    words = re.sub(r"<sos>|<eos>|<PAD>|<UNK>", "", utter).split()
    color, shape = COLOR["other"], SHAPE["shape"]
    if len(words) == 1:
        if words[0] in COLOR:
            color = COLOR[words[0]]
        elif words[0] in SHAPE:
            shape = SHAPE[words[0]]
    elif len(words) == 2:
        color, shape = COLOR[words[0]], SHAPE[words[1]]
    return torch.tensor(np.array(color + shape))

In [None]:
# Enumerate every utterance class: each colour (or no colour) combined with
# each shape word (or no shape word).  The position of an utterance in
# utter_list is its class index.
col_list = [*COLOR]
col_list[-1] = ""          # the trailing "other" entry stands for "no colour word"
shape_list = [*SHAPE]
utter_list = [
    " ".join(part for part in (color_word, shape_word) if part)
    for color_word in col_list
    for shape_word in [*shape_list, ""]
]
print(*utter_list, sep="\n")

In [None]:
import re
def lang2class(utter):
    """Return the one-hot class vector (length ``len(utter_list)``) of an utterance.

    Special tokens are stripped and the remaining words are re-joined with
    single spaces before the lookup.

    Raises:
        ValueError: if the cleaned utterance is not an entry of ``utter_list``.
    """
    without_tokens = re.sub(r"<sos>|<eos>|<PAD>|<UNK>", "", utter)
    cleaned = " ".join(word for word in without_tokens.split(" ") if word)
    target = torch.zeros(len(utter_list))
    class_idx = utter_list.index(cleaned)
    target[class_idx] = 1.0
    return target

### Prepare the Data loader

In [None]:
# Load the first four raw files, preview one sample of each, and stack them
# into a single training set.
raw_parts = []
for i in range(4):
    d = load_raw_data(os.path.join(data_path, data_list[i]))
    check_raw_data(d["imgs"], d["labels"], d["langs"])
    raw_parts.append(d)
# Stack once at the end instead of re-copying the growing arrays for every file.
imgs = np.vstack([part["imgs"] for part in raw_parts])
labels = np.vstack([part["labels"] for part in raw_parts])
langs = np.hstack([part["langs"] for part in raw_parts])
# As before, the dict of the last loaded file is reused to carry the merged arrays.
d["imgs"] = imgs
d["labels"] = labels
d["langs"] = langs

# Every raw utterance must be one of the enumerated classes.
all_in = [" ".join(l) in utter_list for l in d["langs"]]
print("All lang in list: ", all(all_in))

print(d["imgs"].shape, d["labels"].shape, d["langs"].shape)
train_dataset = ShapeWorld(d, vocab)
print(len(train_dataset))
train_batch = DataLoader(train_dataset, batch_size=64, shuffle=False)

# Check that every encoded utterance coming out of the loader maps to a class.
all_classified = []
for cols, label, lang in train_batch:
    for tokens in lang:
        utter = " ".join([i2w[int(idx)] for idx in tokens])
        try:
            lang2class(utter)
        except ValueError:
            # lang2class raises ValueError for an utterance outside utter_list.
            # Record the failure so the summary below can actually be False.
            print(utter, " not in list")
            all_classified.append(False)
        else:
            all_classified.append(True)
print("All lang classified: ", all(all_classified))

In [None]:
# The last file in data_list is held out as the test set.
d = load_raw_data(os.path.join(data_path, data_list[-1]))
print(d["imgs"].shape)
print(d["labels"].shape)
print(d["langs"].shape)
check_raw_data(d["imgs"], d["labels"], d["langs"], id=1)

# Every raw utterance must be one of the enumerated classes.
all_in = [" ".join(l) in utter_list for l in d["langs"]]
print("All lang in list: ", all(all_in))
test_batch = DataLoader(ShapeWorld(d, vocab), batch_size=32, shuffle=False)

# Check that every encoded utterance coming out of the loader maps to a class.
all_classified = []
for cols, label, lang in test_batch:
    for tokens in lang:
        utter = " ".join([i2w[int(idx)] for idx in tokens])
        try:
            lang2class(utter)
        except ValueError:
            # lang2class raises ValueError for an utterance outside utter_list.
            # Record the failure so the summary below can actually be False.
            print(utter, " not in list")
            all_classified.append(False)
        else:
            all_classified.append(True)
print("All lang classified: ", all(all_classified))

## Model

In [None]:
class Imgs_emb_DeepSet(nn.Module):
    """Combine two image embeddings into a single vector.

    Each embedding goes through its own linear layer and a ReLU; the two
    results are added and passed through a final linear layer.
    """

    def __init__(self, input_size=10, output_size=20):
        super().__init__()
        # Separate input projections for the first and the second embedding.
        self.linear1 = nn.Linear(input_size, output_size)
        self.linear2 = nn.Linear(input_size, output_size)
        # Output projection applied to the summed representation.
        self.linear3 = nn.Linear(output_size, output_size)

    def forward(self, img_emb1, img_emb2):
        """Return the combined embedding with ``output_size`` features per row."""
        first = F.relu(self.linear1(img_emb1))
        second = F.relu(self.linear2(img_emb2))
        return self.linear3(first + second)

In [None]:
from literal_listener_shapeworld import CNN_encoder

class CS_CNN_encoder_Feature(nn.Module):
    """Speaker feature extractor built on two frozen, pre-trained CNN encoders.

    Every image is encoded as the colour encoder output (6 values) concatenated
    with the shape encoder output (4 values).  ``forward`` returns, per sample,
    the target image embedding, a DeepSet summary of the two other images and
    the label vector, concatenated along the feature axis.

    ``input_size`` and ``output_size`` are accepted for interface
    compatibility but are not used.
    """

    def __init__(self, input_size=10, output_size=10):
        super(CS_CNN_encoder_Feature, self).__init__()
        # Pre-trained colour encoder, frozen.
        self.cnn_color_encoder = CNN_encoder(6)
        self.cnn_color_encoder.load_state_dict(torch.load("model_params/cnn_color_model_best-loss.pth",map_location=device))
        for params in self.cnn_color_encoder.parameters():
            params.requires_grad = False
        # Pre-trained shape encoder, frozen.
        self.cnn_shape_encoder = CNN_encoder(4)
        self.cnn_shape_encoder.load_state_dict(torch.load("model_params/cnn_shape_model_best-loss.pth",map_location=device))
        for params in self.cnn_shape_encoder.parameters():
            params.requires_grad = False
        self.deepset_size = 6+4
        self.deepset = Imgs_emb_DeepSet(self.deepset_size, self.deepset_size*2)

    def get_feat_emb(self, feat):
        """Encode a batch of images as [colour output | shape output]."""
        col_embs = self.cnn_color_encoder(feat)
        shape_embs = self.cnn_shape_encoder(feat)
        return torch.hstack((col_embs, shape_embs))

    def forward(self, feats, labels):
        """Build the speaker input features for a batch.

        Args:
            feats: image tensor whose dim 0 is the batch and dim 1 indexes the
                three candidate images of each sample.
            labels: ``(batch, 3)`` tensor; the largest entry of each row marks
                that row's target image.

        Returns:
            ``(feat, target_embs)`` where ``feat`` is
            ``[target_embs | deepset(other images) | labels]``.
        """
        n_batch, n_imgs = labels.shape[0], labels.shape[1]
        rows = torch.arange(n_batch, device=feats.device)
        # Pick the target per row.  An argmax over the flattened label tensor
        # would apply the first row's target index to every sample in the batch.
        target_idx = torch.argmax(labels, dim=1).to(feats.device)
        # Remaining image indices of each row, in ascending order.
        slots = torch.arange(n_imgs, device=feats.device).expand(n_batch, n_imgs)
        other_idx = slots[slots != target_idx.unsqueeze(1)].view(n_batch, n_imgs - 1)

        target_embs = self.get_feat_emb(feats[rows, target_idx])
        other_embs1 = self.get_feat_emb(feats[rows, other_idx[:, 0]])
        other_embs2 = self.get_feat_emb(feats[rows, other_idx[:, 1]])   # (batch_size,10)
        other_embs = self.deepset(other_embs1, other_embs2)             # (batch_size,20)
        feat = torch.hstack((target_embs, other_embs, labels))          # (batch_size,33)
        return feat, target_embs

In [None]:
class Speaker(nn.Module):
    """Utterance classifier: a feature model followed by a three-layer MLP.

    ``feat_model(feats, labels)`` must return ``(features, target_embs)``;
    the features are mapped to a probability distribution over
    ``output_size`` utterance classes.
    """

    def __init__(self, feat_model, feat_size, output_size, hidden_size=100):
        super().__init__()
        self.feat_model = feat_model
        self.linear1 = nn.Linear(feat_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        self.linear3 = nn.Linear(hidden_size, output_size)

    def forward(self, feats, labels):
        """Return ``(class probabilities, target embeddings)``.

        The first element is already softmax-normalised over the last axis.
        """
        hidden, target_embs = self.feat_model(feats, labels)
        for layer in (self.linear1, self.linear2):
            hidden = F.relu(layer(hidden))
        scores = self.linear3(hidden)
        return scores.softmax(dim=-1), target_embs

## Training the model

In [None]:
def check_batch(cols,langs,label,onehot_classs,target,target2,id=0):
    """Plot one sample of a batch with its reference and predicted utterances.

    Args:
        cols: batch of image triples (tensor); ``cols[id]`` is iterated to get the images.
        langs: batch of token-index sequences (reference utterances).
        label: batch of per-image flags; images flagged 1 are titled "Correct".
        onehot_classs: batch of class score vectors; the argmax is looked up
            in ``utter_list``.
        target: per-sample strings decoded from the speaker's own target embedding.
        target2: per-sample strings decoded from the stand-alone frozen CNNs.
        id: index of the sample to display.
    """
    img_list, flags = cols[id], label[id]
    # Join the words once.  Joining the already-joined string a second time
    # inserted a space between every character of the title.
    utterance = " ".join([i2w[int(idx)] for idx in langs[id]])
    predicted = utter_list[int(torch.argmax(onehot_classs[id]))]
    fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(6, 4))
    fig.suptitle(utterance+"\nClassified: "+predicted+"\nCNN insode model: "+target[id]+"\nFreezed CNN: "+target2[id])
    for i, (l, img) in enumerate(zip(flags, img_list)):
        # Swap the first and last axes and convert to numpy for imshow.
        img = img.transpose(2,0).to("cpu").detach().numpy()
        axes[i].imshow(img)
        if l == 1:
            axes[i].set_title("Correct")
    plt.show()

In [None]:
from literal_listener_shapeworld import ShapeWorld_RNN_L0

emb_dim = 768
feat_dim = 10
output_dim = len(utter_list)
speaker_feat = CS_CNN_encoder_Feature(input_size=feat_dim,output_size=feat_dim)
# Speaker input width: target embedding (10) + DeepSet summary of the other
# two images (20) + label vector (3).
speaker = Speaker(speaker_feat, feat_size=feat_dim*3+3,output_size=output_dim)
speaker.to(device)

# Pre-trained literal listener.  The path uses "/" like every other checkpoint
# path here; the previous "model_params\shapeworld_..." contained the invalid
# escape sequence "\s" and only resolved on Windows.
literal_listener = ShapeWorld_RNN_L0(len(w2i)).to(device)
literal_listener.load_state_dict(torch.load("model_params/shapeworld_rnn_full-data_100epoch_l0_last.pth",map_location=device))

# Stand-alone frozen copies of the colour/shape encoders, used for diagnostics.
cnn_color_encoder = CNN_encoder(6).to(device)
cnn_color_encoder.load_state_dict(torch.load("model_params/cnn_color_model_best-loss.pth",map_location=device))
for params in cnn_color_encoder.parameters(): params.requires_grad = False
cnn_shape_encoder = CNN_encoder(4).to(device)
cnn_shape_encoder.load_state_dict(torch.load("model_params/cnn_shape_model_best-loss.pth",map_location=device))
for params in cnn_shape_encoder.parameters(): params.requires_grad = False

optimizer = optim.Adam(speaker.parameters(),lr=0.001)
criterion = nn.CrossEntropyLoss()

epoch = 10

### start

In [None]:
def _class_targets(lang_batch):
    """One-hot utterance-class targets for a batch of token-index sequences."""
    return torch.vstack(
        [lang2class(" ".join([i2w[int(idx)] for idx in seq])) for seq in lang_batch]
    ).to(device)


def _speaker_loss(y_prob, class_label):
    """Cross-entropy between the speaker's class probabilities and the targets.

    ``Speaker.forward`` already applies softmax, while ``nn.CrossEntropyLoss``
    expects unnormalised logits and applies log-softmax itself.  Passing the
    log-probabilities makes that internal log-softmax an identity instead of
    a second softmax on top of the first one.
    """
    return criterion(torch.log(y_prob.clamp_min(1e-12)), class_label)


def _batch_accuracy(y_prob, class_label):
    """Fraction of rows whose arg-max class matches the target, as a float."""
    hits = torch.argmax(class_label, dim=1) == torch.argmax(y_prob, dim=1)
    return hits.float().mean().item()


def _decode_utterances(color_scores, shape_scores):
    """Decode per-sample colour/shape score rows into "<colour> <shape>" strings."""
    color_names, shape_names = list(COLOR.keys()), list(SHAPE.keys())
    return [
        color_names[int(torch.argmax(c))] + " " + shape_names[int(torch.argmax(s))]
        for c, s in zip(color_scores, shape_scores)
    ]


train_loss_list = []
test_loss_list = []
train_acc_list = []
test_acc_list = []
best_loss = float("inf")
best_acc = 0
for i in range(epoch):
    print("##############################################")
    print("Epoch:{}/{}".format(i+1,epoch))
    train_loss = 0
    test_loss = 0
    train_acc = 0
    test_acc = 0

    # ---- training phase ----
    # NOTE(review): train() also switches the frozen CNN encoders inside the
    # speaker to training mode; if they contain dropout or batch-norm layers
    # they should probably stay in eval mode -- confirm against CNN_encoder.
    speaker.train()
    for cols,label,lang in train_batch:
        cols, lang = cols.to(device).to(torch.float), lang.to(device)
        label = label.to(device).to(torch.float)
        optimizer.zero_grad()
        y_prob,target = speaker(cols, label)
        class_label = _class_targets(lang)
        loss = _speaker_loss(y_prob, class_label)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        train_acc += _batch_accuracy(y_prob, class_label)
    batch_train_loss = train_loss/len(train_batch)
    batch_train_acc = train_acc/len(train_batch)

    # ---- evaluation phase ----
    # No backward pass and no optimizer step here: updating the weights on the
    # test batches would leak the test set into training.
    speaker.eval()
    with torch.no_grad():
        for cols,label,lang in test_batch:
            cols, lang = cols.to(device).to(torch.float), lang.to(device)
            label = label.to(device).to(torch.float)
            y_prob,target = speaker(cols, label)
            class_label = _class_targets(lang)
            test_loss += _speaker_loss(y_prob, class_label).item()
            test_acc += _batch_accuracy(y_prob, class_label)
        # Diagnostics are only displayed for the last test batch, so decode them
        # once here instead of running the extra CNNs on every batch.
        # from model
        target_utter = _decode_utterances(target[:,:6], target[:,6:])
        # from encoder directly
        direct_target_utter = _decode_utterances(cnn_color_encoder(cols[:,0]), cnn_shape_encoder(cols[:,0]))
    batch_test_loss = test_loss/len(test_batch)
    batch_test_acc = test_acc/len(test_batch)

    check_batch(cols,lang,label,y_prob,target_utter,direct_target_utter,id=np.random.randint(len(cols)))
    print("Train Loss:{:.2E}, Test Loss:{:.2E}".format(batch_train_loss,batch_test_loss))
    print("Train Acc:{:.2E}, Test Acc:{:.2E}".format(batch_train_acc,batch_test_acc))
    train_loss_list.append(batch_train_loss)
    test_loss_list.append(batch_test_loss)
    train_acc_list.append(batch_train_acc)
    test_acc_list.append(batch_test_acc)
    if batch_test_loss < best_loss:
        print("Best loss saved ...")
        torch.save(speaker.to(device).state_dict(),"model_params/shapeworld_S1-class-ext-deep-dim-with-label_lis=emb-rnn_CS-CNN-encoder_more-data_best_loss.pth")
        best_loss = batch_test_loss
    if batch_test_acc > best_acc:
        print("Best acc saved ...")
        torch.save(speaker.to(device).state_dict(),"model_params/shapeworld_S1-class-ext-deep-dim-with-label_lis=emb-rnn_CS-CNN-encoder_more-data_best_acc.pth")
        best_acc = batch_test_acc


In [None]:
def _plot_history(title, ylabel, train_values, test_values, train_name, test_name):
    """Draw a train curve and a test curve against the 1-based epoch number."""
    epochs = range(1, epoch + 1)
    plt.figure()
    plt.title(title)
    plt.xlabel("epoch")
    plt.ylabel(ylabel)
    plt.plot(epochs, train_values, "b-", label=train_name)
    plt.plot(epochs, test_values, "r--", label=test_name)
    plt.legend()
    plt.show()


# visualization
_plot_history("Train and Test Loss", "Literal_Listener_loss",
              train_loss_list, test_loss_list, "train_loss", "test_loss")
# Accuracies may have been stored as 0-dim tensors; convert them to plain floats.
train_acc_list = list(map(float, train_acc_list))
test_acc_list = list(map(float, test_acc_list))
_plot_history("Train and Test Accuracy", "Accuracy",
              train_acc_list, test_acc_list, "train_acc", "test_acc")

In [None]:
#torch.save(speaker.to(device).state_dict(),"model_params/shapeworld_S1_class_final-30-epoch.pth")

## Accuracy test

In [None]:
# Reload the held-out file (last entry of data_list) for the accuracy tests.
eval_file = os.path.join(data_path, data_list[-1])
d = load_raw_data(eval_file)
for field in ("imgs", "labels", "langs"):
    print(d[field].shape)
check_raw_data(d["imgs"], d["labels"], d["langs"], id=1)
eval_batch = DataLoader(ShapeWorld(d, vocab), batch_size=32, shuffle=False)

In [None]:
import re
def check_data(imgs, label, g_langs, c_langs, cap="Correct"):
    """Plot three images side by side with the generated and reference utterances.

    Args:
        imgs: iterable of image tensors; each has its first and last axes swapped
            before plotting.
        label: score vector; the image at its argmax position is captioned.
        g_langs: words of the generated utterance.
        c_langs: words of the reference utterance.
        cap: caption placed above the argmax image.
    """
    fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(6, 4))
    generated, reference = " ".join(g_langs), " ".join(c_langs)
    # Drop the sentence delimiters from the figure title.
    fig.suptitle(re.sub(r"<sos>|<eos>", "", "Generated: " + generated + "\n\nCorrect: " + reference))
    highlighted = int(torch.argmax(label))
    for pos, (_, img) in enumerate(zip(label, imgs)):
        panel = axes[pos]
        panel.imshow(img.transpose(2, 0))
        if pos == highlighted:
            panel.set_title(cap)
    plt.show()

In [None]:
# Restore the speaker weights saved at the best test accuracy.
best_acc_state = torch.load(
    "model_params/shapeworld_S1-class-ext-deep-dim-with-label_lis=emb-rnn_CS-CNN-encoder_more-data_best_acc.pth",
    map_location=device,
)
speaker.load_state_dict(best_acc_state)
speaker.to(device)
print("model loaded successfully")

### Classification accuracy

In [None]:
speaker.eval()

losss = []
accs = []
batch_sizes = []
# Inference only: skip building the autograd graph.
with torch.no_grad():
    for i,(cols,label,lang) in enumerate(eval_batch):
        cols, lang = cols.to(device).to(torch.float), lang.to(device)
        label = label.to(device).to(torch.float)
        y_prob,target = speaker(cols, label)
        class_label = torch.vstack(tuple([lang2class(" ".join([i2w[int(idx)] for idx in l])) for l in lang])).to(device)
        hits = torch.argmax(class_label,dim=1)==torch.argmax(y_prob,dim=1)
        accs.append(hits.float().mean().item())
        batch_sizes.append(len(hits))
        if i%10 == 0:
            # Show the first sample of every tenth batch.
            imgs = cols[0].to("cpu")
            c_langs = [i2w[idx] for idx in lang[0].to("cpu").tolist()]
            g_langs = utter_list[int(torch.argmax(y_prob[0]))].split(" ")
            check_data(imgs, label[0], g_langs, c_langs)

# Weight each batch by its size so a smaller final batch is not over-represented.
print("Accuracy:,", np.average(accs, weights=batch_sizes) if accs else float("nan"))

### L0 accuracy

In [None]:
def onehot_to_2word(y_vec):
    """Convert a class score vector into a fixed-length token-index sequence.

    The arg-max class is looked up in ``utter_list`` and encoded as
    ``<sos> w1 w2 <eos>``; utterances with fewer than two words are padded
    with ``<PAD>`` before ``<eos>`` so that every sequence has four tokens
    and the results can be stacked into one array.
    """
    # split() without an argument returns [] for the empty utterance class;
    # split(" ") returned [""] and then looked up w2i[""].
    words = utter_list[int(torch.argmax(y_vec))].split()
    padding = ["<PAD>"] * max(0, 2 - len(words))
    return [w2i[w] for w in (["<sos>"] + words + padding + ["<eos>"])]

In [None]:
literal_listener.to(device)
literal_listener.eval()
speaker.eval()

losss = []
accs = []
batch_sizes = []
# Inference only: skip building the autograd graph.
with torch.no_grad():
    for i,(cols,label,lang) in enumerate(eval_batch):
        cols, lang = cols.to(device).to(torch.float), lang.to(device)
        label = label.to(device).to(torch.float)
        # The speaker names the target; the literal listener must pick it back out.
        y_prob,target = speaker(cols, label)
        output_lang = torch.tensor(np.array([onehot_to_2word(y) for y in y_prob])).to(device)
        lis_labels = literal_listener(cols, output_lang)
        loss = criterion(lis_labels,label)
        losss.append(loss.item())
        pred_labels = torch.argmax(lis_labels,dim=1)
        # Read the gold index from the labels instead of assuming it is always 0.
        correct_labels = torch.argmax(label,dim=1)
        accs.append((correct_labels==pred_labels).float().mean().item())
        batch_sizes.append(len(correct_labels))
        if i%10 == 0:
            # Show one random sample of every tenth batch.
            sample = np.random.randint(len(cols))
            imgs = cols[sample].to("cpu")
            c_langs = [i2w[idx] for idx in lang[sample].to("cpu").tolist()]
            g_langs = [i2w[idx] for idx in output_lang[sample].to("cpu").tolist()]
            lis_label = lis_labels[sample]
            print(lis_label.to("cpu").tolist())
            check_data(imgs, lis_label, g_langs, c_langs, cap="L0 prediction")

# Weight each batch by its size so a smaller final batch is not over-represented.
print("Accuracy:,", np.average(accs, weights=batch_sizes) if accs else float("nan"))


### L1 accuracy

In [None]:
def langTensor2idx(langt):
    """Return the ``utter_list`` class index of a token-index sequence.

    The indices are decoded with ``i2w``, special tokens are stripped and the
    remaining words are re-joined with single spaces before the lookup.

    Raises:
        ValueError: if the cleaned utterance is not an entry of ``utter_list``.
    """
    decoded = " ".join(i2w[int(idx)] for idx in langt)
    without_tokens = re.sub(r"<sos>|<eos>|<PAD>|<UNK>", "", decoded)
    cleaned = " ".join(word for word in without_tokens.split(" ") if word)
    return utter_list.index(cleaned)

# Evaluation only: set the modes once, outside the loop.
literal_listener.eval()
speaker.eval()

losss = []
accs = []
batch_sizes = []
with torch.no_grad():
    for i,(cols,label,lang) in enumerate(eval_batch):
        cols, lang = cols.to(device).to(torch.float), lang.to(device)
        label = label.to(device).to(torch.float)
        # Class index of the reference utterance of every sample.
        utter_idx = torch.tensor([langTensor2idx(l) for l in lang], device=device)

        # Treat each image in turn as the hypothetical target and record how
        # likely the speaker is to produce the reference utterance for it.
        output_langs = []
        hypothesis_probs = []
        for k in range(label.shape[1]):
            hypothesis = torch.zeros_like(label)
            hypothesis[:,k] = 1.0
            y_prob,target = speaker(cols, hypothesis)
            output_langs.append(torch.tensor(np.array([onehot_to_2word(y) for y in y_prob])).to(device))
            hypothesis_probs.append(y_prob.gather(1, utter_idx.unsqueeze(1)).squeeze(1))
        probs = torch.stack(hypothesis_probs, dim=1).to("cpu")      # (batch_size, n_images)

        # NOTE(review): probs holds unnormalised speaker probabilities, not
        # logits, so this loss is only a rough indicator -- confirm intent.
        loss = criterion(probs.to(device),label)
        losss.append(loss.item())
        pred_labels = torch.argmax(probs,dim=1)
        # Read the gold index from the labels instead of assuming it is always 0.
        correct_labels = torch.argmax(label,dim=1).to("cpu")
        accs.append((correct_labels==pred_labels).float().mean().item())
        batch_sizes.append(len(correct_labels))

        if i%10 == 0:
            # Show the first sample of every tenth batch, with the utterance
            # generated for each hypothetical target separated by "|".
            sample = 0
            imgs = cols[sample].to("cpu")
            c_langs = [i2w[idx] for idx in lang[sample].to("cpu").tolist()]
            g_langs = []
            for k, generated in enumerate(output_langs):
                if k > 0:
                    g_langs.append("|")
                g_langs += [i2w[idx] for idx in generated[sample].to("cpu").tolist()]
            prob = probs[sample]
            print(prob.to("cpu").tolist())
            check_data(imgs, prob, g_langs, c_langs, cap="L1 prediction")
            print(torch.exp(probs[sample]))
            print(torch.argmax(probs[sample]))

# Weight each batch by its size so a smaller final batch is not over-represented.
print("Accuracy:,", np.average(accs, weights=batch_sizes) if accs else float("nan"))