In [None]:
from pathlib import Path
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from shapeworld_data import load_raw_data, get_vocab, ShapeWorld

In [None]:
# Flag kept as-is; no cell visible in this notebook reads it.
GPT = True

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Device: ", device)

## Prepare data

In [None]:
def check_raw_data(imgs, labels, langs, id=0):
    """Plot the candidate images of one raw sample, titled with its utterance.

    Args:
        imgs: indexable collection; ``imgs[id]`` yields the sample's images.
        labels: indexable collection; ``labels[id]`` yields one flag per image.
        langs: indexable collection; ``langs[id]`` is a sequence of word strings.
        id: index of the sample to display (name kept for caller compatibility).
    """
    # Index the requested sample directly instead of zipping and materialising
    # the whole dataset just to pick one element.
    img_list, label, lang = imgs[id], labels[id], langs[id]
    fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(6, 2))
    fig.suptitle(" ".join(lang))
    for ax, flag, img in zip(axes, label, img_list):
        # Reverse the axis order for imshow -- presumably (C, W, H) -> (H, W, C); TODO confirm.
        ax.imshow(img.transpose((2, 1, 0)))
        if flag == 1:
            ax.set_title("Correct")
    plt.show()

In [None]:
from transformers import GPT2Tokenizer

# Module-level tokenizer loaded from the "distilgpt2" checkpoint; the helpers in
# this notebook read it as a global.
tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")


def sentence2index(sentence):
    """Encode ``sentence`` with the module-level tokenizer and return the ids it produces."""
    return tokenizer.encode(sentence)

In [None]:
# Project root: two levels above the notebook's current working directory.
root = Path(os.path.abspath('')).parent.parent.absolute()
# Pass the components separately so the path is portable: a literal backslash
# separator only resolves on Windows, and "\s" is an invalid string escape.
data_path = os.path.join(root, "data", "shapeworld_np")
print(data_path)
# os.listdir returns entries in arbitrary order; sort them so that picking files
# by position later in the notebook is reproducible across machines and runs.
data_list = sorted(os.listdir(data_path))
print(data_list)

### Generating vocab_dict

In [None]:
# Build the dataset vocabulary with get_vocab from every entry of data_list.
vocab_files = [os.path.join(data_path, fname) for fname in data_list]
vocab = get_vocab(vocab_files)
print(vocab["w2i"])

# One-hot attribute codes. Each named colour/shape owns one position; the
# catch-all entries ("other" for colours, "shape" for shapes) are all zeros.
_COLOR_NAMES = ("white", "green", "gray", "yellow", "red", "blue")
_SHAPE_NAMES = ("square", "circle", "rectangle", "ellipse")

COLOR = {
    name: [int(pos == k) for pos in range(len(_COLOR_NAMES))]
    for k, name in enumerate(_COLOR_NAMES)
}
COLOR["other"] = [0] * len(_COLOR_NAMES)   # inserted last, as in the literal definition

SHAPE = {"shape": [0] * len(_SHAPE_NAMES)}  # generic entry comes first
SHAPE.update({
    name: [int(pos == k) for pos in range(len(_SHAPE_NAMES))]
    for k, name in enumerate(_SHAPE_NAMES)
})

print("Generating Vocab_dict from GPT tokenizer ...")
tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
gpt_vocab_dict = tokenizer.get_vocab()
print("Length of the GPT Vocab list is ", len(gpt_vocab_dict))

# Special ids in the GPT-2 token space. SOS, EOS and UNK all share id 50256.
# PAD reuses an ordinary token id -- presumably the id of the text "pad",
# since check_data strips that word from decoded output; TODO confirm.
PAD = 15636
SOS = 50256
EOS = 50256
UNK = 50256

# Special ids in the dataset's own vocabulary (the one behind w2i / i2w).
original_PAD, original_SOS, original_EOS, original_UNK = 0, 1, 2, 3

# Lookups taken from the dataset vocabulary (w2i is indexed by word strings,
# i2w by integer ids elsewhere in this notebook).
w2i, i2w = vocab["w2i"], vocab["i2w"]

### Prepare the data loader

In [None]:
# Number of files (in data_list order) merged into the working dataset.
N_DATA_FILES = 5

# Load every part first and concatenate once; stacking pairwise inside the loop
# re-copies the already-accumulated image array on every iteration.
raw_parts = []
for i in range(N_DATA_FILES):
    d = load_raw_data(os.path.join(data_path, data_list[i]))
    if i == 0:
        # Sanity-check plot of the first sample of the first file.
        check_raw_data(d["imgs"], d["labels"], d["langs"])
    raw_parts.append(d)
imgs = np.vstack([part["imgs"] for part in raw_parts])
labels = np.vstack([part["labels"] for part in raw_parts])
langs = np.hstack([part["langs"] for part in raw_parts])
del raw_parts  # the per-file arrays are no longer needed once merged

imgs_data_tensor = torch.tensor(imgs, dtype=torch.float)
label_data_tensor = torch.tensor(labels)
context_id_data = [sentence2index(utterance) for utterance in langs]
max_context_len = max(len(c) for c in context_id_data)
# Each row is <sos> + tokens + <eos> + <pad>*, all padded to the same length.
padded_rows = [[SOS] + c + [EOS] + [PAD] * (max_context_len - len(c)) for c in context_id_data]
# Force int64: numpy's default integer width is platform-dependent, while token
# ids are used as embedding indices and loss targets further down.
padded_context_data = torch.tensor(np.array(padded_rows), dtype=torch.long)
print(imgs_data_tensor.shape, label_data_tensor.shape, padded_context_data.shape)

# Samples are (images, utterance ids, label) triples -- the order the loops below unpack.
data = list(zip(imgs_data_tensor, padded_context_data, label_data_tensor))
test_split = 1000   # the last `test_split` samples are held out for evaluation
train_data, test_data = data[:-test_split], data[-test_split:]
print("Train, Test data length = ",len(train_data),",",len(test_data))

train_batch = DataLoader(dataset=train_data,batch_size=32,shuffle=True,num_workers=0)
test_batch = DataLoader(dataset=test_data,batch_size=32,shuffle=False,num_workers=0)

## Model

### Encoder

In [None]:
from literal_listener_shapeworld import CNN_encoder

class Imgs_emb_DeepSet(nn.Module):
    """Fuse two image embeddings into a single vector.

    Each input goes through its own linear layer followed by ReLU; the two
    branches are added and passed through one more linear layer. The two
    branches have separate weights, so swapping the inputs can change the output.

    Args:
        input_size: feature size of each incoming embedding.
        output_size: feature size of the fused embedding.
    """

    def __init__(self, input_size=10, output_size=20):
        super().__init__()
        # Layers are created in this fixed order: linear1, linear2, linear3.
        self.linear1 = nn.Linear(input_size, output_size)
        self.linear2 = nn.Linear(input_size, output_size)
        self.linear3 = nn.Linear(output_size, output_size)

    def forward(self, img_emb1, img_emb2):
        """Return ``linear3(relu(linear1(img_emb1)) + relu(linear2(img_emb2)))``."""
        first_branch = self.linear1(img_emb1).relu()
        second_branch = self.linear2(img_emb2).relu()
        return self.linear3(first_branch + second_branch)


class CS_CNN_Encoder(nn.Module):
    """Encode a (target, distractor, distractor) image triple into one feature vector.

    Two pretrained, frozen CNN encoders (colour with 6 outputs, shape with 4)
    embed every image. The target embedding is concatenated with a fused
    embedding of the two other images and projected by a linear layer.

    Args:
        input_size: size of the target embedding fed to the final linear layer
            (the concatenated colour+shape embedding, 6+4 by construction).
        output_size: size of the returned feature vector.
    """

    def __init__(self, input_size=10, output_size=10):
        super(CS_CNN_Encoder, self).__init__()
        # Pretrained colour classifier, frozen.
        self.cnn_color_encoder = CNN_encoder(6)
        self.cnn_color_encoder.load_state_dict(torch.load("model_params/cnn_color_model_best-loss.pth",map_location=device))
        for params in self.cnn_color_encoder.parameters(): params.requires_grad = False
        # Pretrained shape classifier, frozen.
        self.cnn_shape_encoder = CNN_encoder(4)
        self.cnn_shape_encoder.load_state_dict(torch.load("model_params/cnn_shape_model_best-loss.pth",map_location=device))
        for params in self.cnn_shape_encoder.parameters(): params.requires_grad = False
        self.deepset_size = 6+4
        self.deepset = Imgs_emb_DeepSet(self.deepset_size, self.deepset_size)
        self.linear = nn.Linear(input_size+self.deepset_size, output_size)

    def get_feat_emb(self,feat):
        """Return the colour and shape encoder outputs for ``feat`` stacked side by side."""
        col_embs = self.cnn_color_encoder(feat)
        shape_embs = self.cnn_shape_encoder(feat)
        img_embs = torch.hstack((col_embs,shape_embs))
        return img_embs

    def forward(self,feats,labels):
        """Encode a batch of image triples.

        Args:
            feats: batched images indexed as ``feats[sample, image_slot, ...]``.
            labels: per-sample scores over the image slots, shape
                ``(batch, n_slots)``; the arg-max slot is the target. A 1-D
                tensor is also accepted and applied to every sample.

        Returns:
            Tensor of shape ``(batch, output_size)``.
        """
        batch_size = feats.shape[0]
        if labels.dim() == 1:
            # One shared label vector: broadcast it over the batch.
            labels = labels.unsqueeze(0).expand(batch_size, -1)
        n_slots = labels.shape[1]
        rows = torch.arange(batch_size, device=feats.device)
        # Pick the target slot per sample. An arg-max over the flattened label
        # tensor would apply the first row's target to every sample in the batch.
        target_idx = labels.argmax(dim=1).to(feats.device)
        slots = torch.arange(n_slots, device=feats.device).expand(batch_size, n_slots)
        # Boolean masking keeps row-major order, so each row lists the non-target
        # slots in ascending order.
        other_idx = slots[slots != target_idx.unsqueeze(1)].view(batch_size, n_slots - 1)
        target_img = feats[rows, target_idx]
        other_img1 = feats[rows, other_idx[:, 0]]
        other_img2 = feats[rows, other_idx[:, 1]]
        target_embs = self.get_feat_emb(target_img)
        other_embs1 = self.get_feat_emb(other_img1)
        other_embs2 = self.get_feat_emb(other_img2)                 # (batch_size,10)
        other_embs = F.relu(self.deepset(other_embs1,other_embs2))  # (batch_size,10)
        embs = torch.hstack((target_embs,other_embs))               # (batch_size,20)
        feat = self.linear(embs)                                    # (batch_size,output_size)
        return feat

### Speaker Encoder-Decoder model

In [None]:
# Define the encoder-decoder model
from transformers import EncoderDecoderModel

class S1_EncoderDecoder(torch.nn.Module):
    """Speaker model: image-triple encoder followed by a GPT-2 style decoder.

    The encoder output is reshaped to a length-1 sequence and handed to the
    decoder as ``encoder_hidden_states``.

    Args:
        input_size: forwarded to ``CS_CNN_Encoder`` as its ``input_size``.
        hidden_size: encoder output size; it has to match the width the decoder
            expects for ``encoder_hidden_states`` (presumably 768 for
            distilgpt2 -- TODO confirm).
    """

    def __init__(self, input_size, hidden_size=768):
        super(S1_EncoderDecoder, self).__init__()
        self.hidden_size = hidden_size
        self.encoder = CS_CNN_Encoder(input_size, hidden_size)
        # Only the decoder half of the pretrained encoder-decoder pair is kept.
        self.decoder = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "distilgpt2").decoder

    def _encode(self, feats, labels):
        """Return the encoder output as a ``(batch, 1, hidden_size)`` tensor."""
        return self.encoder(feats, labels).reshape(len(feats), 1, self.hidden_size)

    def forward(self, feats, labels, langs):
        """Teacher-forced pass.

        Args:
            feats: batched image triples passed to the encoder.
            labels: target labels passed to the encoder.
            langs: token ids of shape ``(batch, seq)``; the last position is
                dropped before decoding so outputs align with ``langs[:, 1:]``.

        Returns:
            The first element of the decoder output, which callers treat as
            per-position vocabulary logits.
        """
        decoder_hidden = self._encode(feats, labels)
        decoder_input = langs[:,:-1]
        decoder_output = self.decoder(input_ids=decoder_input, encoder_hidden_states=decoder_hidden)
        return decoder_output[0]

    def generate(self,feats,labels,max_len=5,temperature=0.7):
        """Sample ``max_len`` tokens per sample with temperature sampling.

        Returns:
            generated: ``(batch, max_len + 1)`` token ids, starting with the
                module-level ``SOS`` id.
            probs_list: ``(batch, max_len + 1, vocab)`` distributions; position 0
                is a one-hot on ``SOS``, position t is the distribution token t
                was sampled from.
        """
        batch_size = len(feats)
        decoder_hidden = self._encode(feats, labels)
        dev = decoder_hidden.device
        # Take the vocabulary size from the decoder itself instead of a literal
        # so the probability tensor always matches the logits width.
        vocab_size = self.decoder.config.vocab_size
        # Start token and its one-hot distribution both come from the same SOS id.
        generated = torch.full((batch_size, 1), SOS, dtype=torch.long, device=dev)
        start_probs = torch.zeros(batch_size, 1, vocab_size, device=dev)
        start_probs[:, 0, SOS] = 1.0
        step_probs = [start_probs]
        for _ in range(max_len):
            decoder_output = self.decoder(input_ids=generated, encoder_hidden_states=decoder_hidden)
            # Only the distribution for the next position is needed.
            logits = decoder_output[0][:,-1,:]/temperature
            probs = F.softmax(logits,dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            generated = torch.cat((generated, next_token), dim=1)
            step_probs.append(probs.unsqueeze(1))
        # Concatenate once at the end rather than re-copying the tensor every step.
        probs_list = torch.cat(step_probs, dim=1)
        return generated,probs_list

## Training

### Setting

In [None]:
import re
def check_data(img_list,c_lang,g_lang,lis_label,cscnn=None,gg_lang=None):
    """Plot one sample's images with reference/generated utterances as the title.

    Args:
        img_list: iterable of image tensors for one sample (up to three are drawn).
        c_lang: reference utterance as a sequence of word strings.
        g_lang: generated utterance as a sequence of word strings.
        lis_label: tensor of per-image scores; the arg-max image is marked
            "Listener Prediction".
        cscnn: optional extra title line (string).
        gg_lang: optional decoded string; "<|endoftext|>" and "pad" are blanked
            out before it is shown.
    """
    fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(6, 4))
    context = "Correct: "+" ".join(c_lang)
    if cscnn is not None: context += "\nCSCNN:"+cscnn
    context += "\nGenerated:"+" ".join(g_lang)
    if gg_lang is not None: context += "\n Generated02:"+re.sub(r"<\|endoftext\|>|pad"," ",gg_lang)
    fig.suptitle(context)
    predicted = int(torch.argmax(lis_label))
    # Iterate over the function's own arguments only; do not depend on a
    # module-level `label` left behind by whichever cell ran last.
    for i, (ax, img) in enumerate(zip(axes, img_list)):
        # Swap the first and last axes for imshow -- presumably channel-first to channel-last; TODO confirm.
        ax.imshow(img.transpose(2,0).to("cpu").detach().numpy())
        if i == predicted:
            ax.set_title("Listener Prediction")
    plt.show()

In [None]:
from literal_listener_shapeworld import ShapeWorld_RNN_L0, CNN_encoder

feat_dim = 10   # size of the concatenated colour(6)+shape(4) embedding
speaker = S1_EncoderDecoder(input_size=feat_dim)
speaker.to(device)

literal_listener = ShapeWorld_RNN_L0(len(w2i)).to(device)
# Forward slashes work on every platform; a backslash separator is Windows-only
# and "\s" is an invalid string escape.
literal_listener.load_state_dict(torch.load("model_params/shapeworld_rnn_full-data_100epoch_l0_last.pth",map_location=device))

# Stand-alone frozen CNN classifiers. In this notebook they are only called under
# torch.no_grad() to label images for display, so keep them in eval mode.
cnn_color_encoder = CNN_encoder(6).to(device)
cnn_color_encoder.load_state_dict(torch.load("model_params/cnn_color_model_best-loss.pth",map_location=device))
for params in cnn_color_encoder.parameters(): params.requires_grad = False
cnn_color_encoder.eval()
cnn_shape_encoder = CNN_encoder(4).to(device)
cnn_shape_encoder.load_state_dict(torch.load("model_params/cnn_shape_model_best-loss.pth",map_location=device))
for params in cnn_shape_encoder.parameters(): params.requires_grad = False
cnn_shape_encoder.eval()

optimizer = optim.Adam(speaker.parameters(),lr=0.001)
# Padding positions do not contribute to the language-modelling loss.
criterion = nn.CrossEntropyLoss(ignore_index=PAD)

max_len = 5    # number of tokens sampled by speaker.generate

epoch = 100    # number of training epochs

### Start

In [None]:
from tqdm import tqdm

train_loss_list = []
test_loss_list = []
# Start from +inf so the first epoch always counts as an improvement, whatever
# the scale of the loss.
best_loss = float("inf")

# Index -> name lookups derived from the one-hot codes themselves. Taking
# list(SHAPE.keys())[argmax] would be shifted by one because the all-zero
# "shape" entry is the first key. Assumes the CNN outputs follow the one-hot
# layout of COLOR / SHAPE -- TODO confirm against the CNN training code.
color_by_index = {code.index(1): name for name, code in COLOR.items() if 1 in code}
shape_by_index = {code.index(1): name for name, code in SHAPE.items() if 1 in code}

for i in range(epoch):
    print("##############################################")
    print("Epoch:{}/{}".format(i+1,epoch))
    train_loss = 0
    test_loss = 0

    speaker.train()
    for cols,lang,label in tqdm(train_batch):
        cols, lang, label = cols.to(device), lang.to(device).long(), label.to(device)
        optimizer.zero_grad()
        output = speaker(cols, label, lang)
        # Next-token objective: position t of the output predicts lang[:, t+1].
        output_view = output.reshape(-1, output.shape[-1])
        target = lang[:,1:].reshape(-1)
        lang_loss = criterion(output_view, target)
        # A fresh graph is built every step, so there is nothing to retain.
        lang_loss.backward()
        optimizer.step()
        train_loss += lang_loss.item()
    batch_train_loss = train_loss/len(train_batch)

    speaker.eval()
    with torch.no_grad():
        for cols,lang,label in tqdm(test_batch):
            cols, lang, label = cols.to(device), lang.to(device).long(), label.to(device)
            output = speaker(cols, label, lang)
            output_view = output.reshape(-1, output.shape[-1])
            target = lang[:,1:].reshape(-1)
            lang_loss = criterion(output_view, target)
            test_loss += lang_loss.item()
        batch_test_loss = test_loss/len(test_batch)

    print("Train Loss:{:.2E}, Test Loss:{:.2E}".format(batch_train_loss,batch_test_loss))
    train_loss_list.append(batch_train_loss)
    test_loss_list.append(batch_test_loss)
    if batch_test_loss < best_loss:
        print("Best loss saved ...")
        torch.save(speaker.state_dict(),"model_params/shapeworld_S1-GPT2-decoder_loss=Lang_best-loss_100epoch.pth")
        best_loss = batch_test_loss
    if i%10 == 0:
        # Only the last test batch is visualised, and only every 10th epoch, so
        # sample utterances and label images here instead of for every batch.
        with torch.no_grad():
            generated, _ = speaker.generate(cols,label,max_len=max_len)
            # The image in slot 0 is treated as the target for the CNN-derived caption.
            target_col, target_shape = cnn_color_encoder(cols[:,0]), cnn_shape_encoder(cols[:,0])
        direct_target_utter = [
            color_by_index[int(torch.argmax(col_vec))]+" "+shape_by_index[int(torch.argmax(shape_vec))]
            for col_vec, shape_vec in zip(target_col, target_shape)
        ]
        sample_id = np.random.randint(len(cols))
        c_langs = tokenizer.decode([idx for idx in lang[sample_id].to("cpu").tolist() if idx not in [PAD,SOS,EOS]]).split(" ")
        g_langs = tokenizer.decode([idx for idx in output.argmax(2)[sample_id].to("cpu").tolist() if idx not in [PAD,SOS,EOS]]).split(" ")
        # No listener is run during training; mark slot 0 in the plot.
        lis_label = torch.zeros(3)
        lis_label[0] = 1.0
        check_data(cols[sample_id].to("cpu"), c_langs, g_langs, lis_label, cscnn=direct_target_utter[sample_id],gg_lang=tokenizer.decode(generated[sample_id]))

## Accuracy

In [None]:
from tqdm import tqdm
from literal_listener_shapeworld import ShapeWorld_RNN_L0, CNN_encoder

feat_dim = 10   # size of the concatenated colour(6)+shape(4) embedding
speaker = S1_EncoderDecoder(input_size=feat_dim).to(device)
# Checkpoint written by the training cell whenever the test loss improves.
best = "model_params/shapeworld_S1-GPT2-decoder_loss=Lang_best-loss_100epoch.pth"
speaker.load_state_dict(torch.load(best,map_location=device))

literal_listener = ShapeWorld_RNN_L0(len(w2i)).to(device)
# Forward slashes work on every platform; a backslash separator is Windows-only
# and "\s" is an invalid string escape.
literal_listener.load_state_dict(torch.load("model_params/shapeworld_rnn_full-data_100epoch_l0_last.pth",map_location=device))

print("Model loaded ...")

optimizer = optim.Adam(speaker.parameters(),lr=0.001)
# Plain cross-entropy (no ignore_index): in the cells below it scores
# predictions over the three candidate images.
criterion = nn.CrossEntropyLoss()

max_len = 5    # number of tokens sampled by speaker.generate


### L0 accuracy

In [None]:
# Every "<colour> <shape>" combination, including colour-only and shape-only
# forms: the last colour ("other") and one extra shape entry are blanked out.
col_list = [*COLOR]
col_list[-1] = ""
shape_list = [*SHAPE]
_word_pairs = [f"{c} {s}".split() for c in col_list for s in [*shape_list, ""]]
utter_list = [" ".join(pair) for pair in _word_pairs]      # e.g. "red square"
gpt_utter_list = ["".join(pair) for pair in _word_pairs]   # e.g. "redsquare"
# Map the fused spelling back to the space-separated one; later duplicates win.
vocab2gpt = dict(zip(gpt_utter_list, utter_list))
for g,u in vocab2gpt.items():
    print(u," : ",g)

In [None]:
from nltk.tokenize import word_tokenize
from functools import reduce

def decode_gpt_vocab(w):
    """Map one decoded word to a list of ids in the dataset vocabulary.

    Args:
        w: a single word string.

    Returns:
        ``[w2i[w]]`` when the word is in the vocabulary; the ids of its parts
        when it is a fused form listed in ``vocab2gpt``; otherwise
        ``[original_UNK]``.
    """
    if w in w2i:
        return [w2i[w]]
    if w in vocab2gpt:
        # A fused form can contain a part that is missing from the dataset
        # vocabulary; map such parts to UNK instead of raising KeyError.
        return [w2i.get(t, original_UNK) for t in vocab2gpt[w].split(" ")]
    return [original_UNK]

def gpt_lang2L0_lang(generated_langs):
    """Convert sampled GPT token ids into padded id rows in the dataset vocabulary.

    For every row the PAD/SOS/EOS ids are dropped, the rest is decoded to text,
    split with ``word_tokenize`` and each word is mapped through
    ``decode_gpt_vocab``. Rows are wrapped as
    ``original_SOS + ids + original_EOS`` and right-padded with
    ``original_PAD`` to a common length.

    Args:
        generated_langs: iterable of token-id rows (e.g. a 2-D tensor).

    Returns:
        Integer tensor with one padded row per input row.
    """
    special_ids = [PAD, SOS, EOS]
    token_rows = []
    for generated in generated_langs:
        text = tokenizer.decode([idx for idx in generated if idx not in special_ids])
        # Flatten the per-word id lists into a single row.
        token_rows.append([tok for word in word_tokenize(text) for tok in decode_gpt_vocab(word)])
    width = max(map(len, token_rows))
    padded_rows = [
        [original_SOS, *row, original_EOS, *([original_PAD] * (width - len(row)))]
        for row in token_rows
    ]
    return torch.tensor(np.array(padded_rows))

In [None]:
accs = []
batch_sizes = []
speaker.eval()
literal_listener.eval()   # set once; the listener is never switched back to train mode here
with torch.no_grad():
    for i,(cols,lang,label) in enumerate(test_batch):
        cols, lang, label = cols.to(device), lang.to(device), label.to(device)
        generated_lang, lang_probs = speaker.generate(cols,label,max_len=max_len)
        # Translate the GPT-token utterances into the listener's own vocabulary.
        output_lang = gpt_lang2L0_lang(generated_lang).to(device)
        lis_labels = literal_listener(cols, output_lang)
        pred_labels = torch.argmax(lis_labels,dim=1)
        # Read the ground truth from the batch labels instead of assuming the
        # target is always the image in slot 0.
        correct_labels = torch.argmax(label,dim=1)
        acc = (pred_labels == correct_labels).float().mean()
        accs.append(acc.item())
        batch_sizes.append(len(cols))
        if i%10 == 0:
            print(i+1,"/",len(test_batch))
            sample_id = np.random.randint(len(cols))
            print([i2w[int(idx)] for idx in output_lang[sample_id]])
            c_langs = tokenizer.decode([idx for idx in lang[sample_id].to("cpu").tolist() if idx not in [PAD,SOS,EOS]]).split(" ")
            g_langs = tokenizer.decode([idx for idx in generated_lang[sample_id].to("cpu").tolist() if idx not in [PAD,SOS,EOS]]).split(" ")
            # Narrow `label` to the plotted sample (rebinding kept from the original cell).
            label = label[sample_id]
            check_data(cols[sample_id].to("cpu"), c_langs, g_langs, lis_labels[sample_id])

# Weight every batch by its size so a smaller final batch does not skew the mean.
print("Accuracy: ",np.average(accs, weights=batch_sizes))

### L1 accuracy

In [None]:
def _utterance_log_scores(lang_probs, lang, eps=0.001):
    """Sum over positions of log(p + eps) that each step distribution gives the reference token.

    ``lang_probs`` is (batch, steps, vocab) and ``lang`` is (batch, length);
    only the positions present in both are scored.
    """
    n_steps = min(lang_probs.shape[1], lang.shape[1])
    ref_ids = lang[:, :n_steps].long().unsqueeze(-1)
    token_probs = lang_probs[:, :n_steps].gather(2, ref_ids).squeeze(-1)
    return torch.log(token_probs + eps).sum(dim=1)


losss = []
accs = []
batch_sizes = []
speaker.eval()

# Evaluation only: no_grad keeps the three sampling passes per batch from
# building autograd graphs.
with torch.no_grad():
    for i,(cols,lang,label) in enumerate(test_batch):
        cols, lang, label = cols.to(device).to(torch.float), lang.to(device), label.to(device).to(torch.float)

        # Score the reference utterance under the speaker once per hypothesised target slot.
        # NOTE(review): the step distributions come from sampled generation, i.e.
        # they are conditioned on sampled tokens rather than on the reference
        # utterance, and PAD positions are scored too -- confirm this is intended.
        generated_by_slot = []
        slot_scores = []
        for slot in range(label.shape[1]):
            hyp_label = torch.zeros_like(label)
            hyp_label[:, slot] = 1.0
            generated_k, lang_probs_k = speaker.generate(cols,hyp_label,max_len=max_len)
            generated_by_slot.append(generated_k)
            slot_scores.append(_utterance_log_scores(lang_probs_k, lang))
        scores = torch.stack(slot_scores, dim=1)          # (batch, n_slots) log-scores
        probs = F.softmax(scores, dim=-1)

        # CrossEntropyLoss applies log-softmax itself, so it must receive the
        # raw scores, not the already-normalised probabilities.
        correct_labels = torch.argmax(label, dim=1)
        loss = criterion(scores, correct_labels)
        losss.append(loss.item())
        pred_labels = torch.argmax(probs,dim=1)
        # Ground truth comes from the batch labels rather than a fixed slot 0.
        acc = (pred_labels == correct_labels).float().mean()
        accs.append(acc.item())
        batch_sizes.append(len(cols))

        if i%10 == 0:
            sample_id = np.random.randint(len(cols))
            # Separate name so the dataset array `imgs` is not overwritten.
            sample_imgs = cols[sample_id].to("cpu")
            c_langs = tokenizer.decode([idx for idx in lang[sample_id].to("cpu").tolist() if idx not in [PAD,SOS,EOS]]).split(" ")
            g_langs = []
            for slot, generated_k in enumerate(generated_by_slot):
                if slot:
                    g_langs.append("|")
                g_langs += tokenizer.decode([idx for idx in generated_k[sample_id].to("cpu").tolist() if idx not in [PAD,SOS,EOS]]).split(" ")
            # Narrow `label` to the plotted sample (rebinding kept from the original cell).
            label = label[sample_id]
            sample_probs = probs[sample_id].to("cpu")
            check_data(sample_imgs, c_langs, g_langs, sample_probs)
            # `probs` already holds probabilities, so print them directly (no exp).
            print(sample_probs)
            print(torch.where(sample_probs>0.1,1,0))

print("Loss: ",np.average(losss, weights=batch_sizes))
print("Accuracy: ",np.average(accs, weights=batch_sizes))